Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

163
csrc/cpu/activation.cpp Normal file
View File

@@ -0,0 +1,163 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(d % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
int start = i * d;
if constexpr (is_gated) {
start *= 2;
}
const scalar_vec_t x(input + start + j);
const vec_op::FP32Vec8 f32_x(x);
vec_op::FP32Vec8 f32_ans = func(f32_x);
if constexpr (is_gated) {
const scalar_vec_t y(input + start + d + j);
const vec_op::FP32Vec8 f32_y(y);
f32_ans = f32_y * f32_ans;
}
const scalar_vec_t result(f32_ans);
result.save(output + i * d + j);
}
}
}
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 x3 = x * x * x;
const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(1.702f);
return x / (ones + (zeros - w1 * x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5);
const vec_op::FP32Vec8 w3(0.044715);
const vec_op::FP32Vec8 x_3 = x * x * x;
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh());
}
}; // namespace
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
activation_kernel<scalar_t, gelu_tanh_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
});
}
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_new_impl)
activation_kernel<scalar_t, gelu_new_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_new_impl)
});
}
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_fast_impl)
activation_kernel<scalar_t, gelu_fast_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_quick_impl)
activation_kernel<scalar_t, gelu_quick_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
});
}

265
csrc/cpu/cpu_attn.cpp Normal file
View File

@@ -0,0 +1,265 @@
#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_attention::ISA::AMX: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
#endif // #ifdef __aarch64__
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
return __VA_ARGS__(); \
}
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
[&] { \
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
default: { \
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
std::to_string(HEAD_DIM)); \
} \
} \
}()
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
case cpu_attention::ISA::VEC16: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
} \
} \
}()
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split) {
cpu_attention::ISA isa;
if (isa_hint == "amx") {
isa = cpu_attention::ISA::AMX;
} else if (isa_hint == "vec") {
isa = cpu_attention::ISA::VEC;
} else if (isa_hint == "vec16") {
isa = cpu_attention::ISA::VEC16;
} else if (isa_hint == "neon") {
isa = cpu_attention::ISA::NEON;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
cpu_attention::AttentionScheduler::ScheduleInput input;
input.num_reqs = num_req;
input.num_heads_q = num_heads_q;
input.num_heads_kv = num_heads_kv;
input.head_dim = head_dim;
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
if (window_size != -1) {
input.left_sliding_window_size = window_size - 1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = window_size - 1;
}
} else {
input.left_sliding_window_size = -1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = -1;
}
}
input.casual = casual;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
input.output_buffer_elem_size =
sizeof(attn_impl::partial_output_buffer_t);
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
});
});
});
cpu_attention::AttentionScheduler scheduler;
torch::Tensor metadata = scheduler.schedule(input);
return metadata;
}
void cpu_attn_reshape_and_cache(
const torch::Tensor& key, // [token_num, head_num, head_size]
const torch::Tensor& value, // [token_num, head_num, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor& slot_mapping, const std::string& isa) {
TORCH_CHECK_EQ(key.dim(), 3);
TORCH_CHECK_EQ(value.dim(), 3);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
TORCH_CHECK_EQ(key.stride(2), 1);
TORCH_CHECK_EQ(value.stride(2), 1);
const int64_t token_num = key.size(0);
const int64_t key_token_num_stride = key.stride(0);
const int64_t value_token_num_stride = value.stride(0);
const int64_t head_num = value.size(1);
const int64_t key_head_num_stride = key.stride(1);
const int64_t value_head_num_stride = value.stride(1);
const int64_t num_blocks = key_cache.size(0);
const int64_t num_blocks_stride = key_cache.stride(0);
const int64_t cache_head_num_stride = key_cache.stride(1);
const int64_t block_size = key_cache.size(2);
const int64_t block_size_stride = key_cache.stride(2);
const int64_t head_dim = key.size(-1);
cpu_attention::ISA isa_tag = [&]() {
if (isa == "amx") {
return cpu_attention::ISA::AMX;
} else if (isa == "vec") {
return cpu_attention::ISA::VEC;
} else if (isa == "vec16") {
return cpu_attention::ISA::VEC16;
} else if (isa == "neon") {
return cpu_attention::ISA::NEON;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}
}();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num,
key_token_num_stride, value_token_num_stride, head_num,
key_head_num_stride, value_head_num_stride, num_blocks,
num_blocks_stride, cache_head_num_stride, block_size,
block_size_stride);
});
});
});
}
void cpu_attention_with_kv_cache(
const torch::Tensor& query, // [num_tokens, num_heads, head_size]
const torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor& output, // [num_tokens, num_heads, head_size]
const torch::Tensor& query_start_loc, // [num_tokens + 1]
const torch::Tensor& seq_lens, // [num_tokens]
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes, // [num_heads]
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, // [num_tokens, max_block_num]
const double softcap, const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux // [num_heads]
) {
TORCH_CHECK_EQ(query.dim(), 3);
TORCH_CHECK_EQ(query.stride(2), 1);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
cpu_attention::AttentionInput input;
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
scheduler_metadata.data_ptr());
input.num_tokens = query.size(0);
input.num_heads = query.size(1);
input.num_kv_heads = key_cache.size(1);
input.block_size = key_cache.size(2);
input.query = query.data_ptr();
input.query_num_tokens_stride = query.stride(0);
input.query_num_heads_stride = query.stride(1);
input.cache_num_blocks_stride = key_cache.stride(0);
input.cache_num_kv_heads_stride = key_cache.stride(1);
input.blt_num_tokens_stride = block_table.stride(0);
input.key_cache = key_cache.data_ptr();
input.value_cache = value_cache.data_ptr();
input.output = output.data_ptr();
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
input.block_table = block_table.data_ptr<int32_t>();
input.alibi_slopes =
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
// For now sink must be bf16
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
input.scale = scale;
input.causal = causal;
input.sliding_window_left = sliding_window_left;
input.sliding_window_right = sliding_window_right;
if (input.causal) {
// to make boundary calculation easier
input.sliding_window_right = 0;
}
float softcap_fp32 = softcap;
input.softcap = softcap_fp32;
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
});
});
}

511
csrc/cpu/cpu_attn_amx.hpp Normal file
View File

@@ -0,0 +1,511 @@
#ifndef CPU_ATTN_AMX_HPP
#define CPU_ATTN_AMX_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// AMX specific
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
typedef struct __tile_config {
uint8_t palette_id = 1;
uint8_t start_row = 0;
uint8_t reserved_0[14] = {0};
uint16_t colsb[16] = {0};
uint8_t rows[16] = {0};
} __tilecfg;
// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
// - 16
template <typename kv_cache_t>
class TileGemm224 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
void* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
}
};
template <>
class TileGemm224<c10::BFloat16> {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
c10::BFloat16* __restrict__ a_tile,
c10::BFloat16* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
const int32_t k_times =
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
const int64_t a_tile_stride = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return AMX_TILE_ROW_BYTES;
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return lda * sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// k_cache is prepacked
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// v_cache is prepacked
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
// k_cache, v_cache are prepacked
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
// logits_buffer, output_buffer are not prepacked
float* __restrict__ c_tile_4 = c_tile;
float* __restrict__ c_tile_5 =
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc;
float* __restrict__ c_tile_7 =
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
const int32_t c_tile_stride = ldc * sizeof(float);
if (accum_c) {
_tile_loadd(4, c_tile_4, c_tile_stride);
_tile_loadd(5, c_tile_5, c_tile_stride);
_tile_loadd(6, c_tile_6, c_tile_stride);
_tile_loadd(7, c_tile_7, c_tile_stride);
} else {
_tile_zero(4);
_tile_zero(5);
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
_tile_dpbf16ps(4, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
_tile_dpbf16ps(5, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_dpbf16ps(6, 1, 2);
_tile_dpbf16ps(7, 1, 3);
// update ptrs
if constexpr (phase == AttentionGemmPhase::QK) {
// Q buffer is prepacked
a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// P buffer is not prepacked
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
_tile_stored(4, c_tile_4, c_tile_stride);
_tile_stored(5, c_tile_5, c_tile_stride);
_tile_stored(6, c_tile_6, c_tile_stride);
_tile_stored(7, c_tile_7, c_tile_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
const int32_t m_0 = AMX_TILE_ROW_NUM;
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
config.rows[0] = m_0;
config.rows[1] = m_1;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = m_0;
config.rows[5] = m_0;
config.rows[6] = m_1;
config.rows[7] = m_1;
_tile_loadconfig(&config);
}
};
// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
// m, m
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
// num should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be
// m
template <typename kv_cache_t>
class TileGemm122 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
void* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
}
};
template <>
class TileGemm122<c10::BFloat16> {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
c10::BFloat16* __restrict__ a_tile,
c10::BFloat16* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
c10::BFloat16* __restrict__ a_tile_1 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
const int64_t a_tile_stride = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return AMX_TILE_ROW_BYTES;
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return lda * sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// k_cache is prepacked
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// v_cache is prepacked
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_4 =
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_5 =
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
int64_t b_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_6 = c_tile;
float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float);
int64_t c_stride = ldc * sizeof(float);
const int32_t k_times =
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
const int32_t k_group_times = k_times / 2;
const bool has_tail = (k_times % 2 == 1);
if (accum_c) {
_tile_loadd(6, c_tile_6, c_stride);
_tile_loadd(7, c_tile_7, c_stride);
} else {
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_group_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_stream_loadd(4, b_tile_4, b_stride);
_tile_dpbf16ps(6, 1, 4);
_tile_stream_loadd(5, b_tile_5, b_stride);
_tile_dpbf16ps(7, 1, 5);
// update ptrs
if constexpr (phase == AttentionGemmPhase::QK) {
// Q buffer is prepacked
a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// P buffer is not prepacked
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
}
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
if (has_tail) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
}
_tile_stored(6, c_tile_6, c_stride);
_tile_stored(7, c_tile_7, c_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
config.rows[0] = m;
config.rows[1] = m;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = AMX_TILE_ROW_NUM;
config.rows[5] = AMX_TILE_ROW_NUM;
config.rows[6] = m;
config.rows[7] = m;
_tile_loadconfig(&config);
}
};
} // namespace
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = scalar_t;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = scalar_t;
constexpr static int64_t BlockSizeAlignment =
AMX_TILE_ROW_BYTES /
sizeof(kv_cache_t); // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 32;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::AMX;
constexpr static bool scale_on_logits = true;
public:
AttentionImpl() : current_q_head_num_(0) {
// Use all columns in AMX tiles
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
}
~AttentionImpl() { _tile_release(); }
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
if (q_head_num > AMX_TILE_ROW_NUM) {
if (q_head_num != current_q_head_num_) {
current_q_head_num_ = q_head_num;
TileGemm224<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
}
attention<TileGemm224<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
} else {
if (q_head_num != current_q_head_num_) {
current_q_head_num_ = q_head_num;
TileGemm122<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
}
attention<TileGemm122<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment * head_dim;
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment * (AMX_TILE_ROW_BYTES / 4);
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return block_size * HeadDimAlignment;
}
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
scalar_t* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, const float scale) {
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
constexpr int64_t head_elem_num_pre_block =
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
int32_t idx = 0;
int8_t* __restrict__ q_buffer_iter = reinterpret_cast<int8_t*>(q_buffer);
for (int32_t q_num_idx = 0; q_num_idx < q_num;
++q_num_idx, src += q_num_stride) {
scalar_t* __restrict__ src_iter = src;
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv;
++q_head_idx, src_iter += q_head_stride) {
vec_op::unroll_loop<int32_t, head_size_block_num>(
[&](int32_t head_size_block_idx) {
// Use INT8Vec64 for 64 bytes block
vec_op::INT8Vec64 vec(src_iter + head_size_block_idx *
head_elem_num_pre_block);
vec.save(q_buffer_iter + head_size_block_idx * AMX_TILE_BYTES);
});
++idx;
q_buffer_iter += AMX_TILE_ROW_BYTES;
if ((idx & (AMX_TILE_ROW_NUM - 1)) == 0) {
// head is in another amx tile
q_buffer_iter -= AMX_TILE_ROW_NUM * AMX_TILE_ROW_BYTES;
q_buffer_iter += head_size_block_num * AMX_TILE_BYTES;
}
}
}
}
// reshape KV to AMX friendly layout
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
// For AMX 2D tiles, size of each line is 64 bytes
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
// For AMX B martix, N always is 16
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
// For now suppose block_size is divisible by amx_tile_column_num
TORCH_CHECK_EQ(block_size % amx_b_tile_k_size, 0);
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key
// Head elements should be packed as quand-words and stored in token
// groups with (quadword_stride/4) tokens
constexpr int64_t token_num_per_group = amx_tile_row_size / 4;
static_assert(head_dim % (4 / sizeof(scalar_t)) == 0);
constexpr int64_t quadword_num = head_dim / (4 / sizeof(scalar_t));
const int32_t* key_start_quadword_ptr =
reinterpret_cast<const int32_t*>(
key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride);
const int64_t group_idx = block_offset / token_num_per_group;
const int64_t group_offset = block_offset % token_num_per_group;
constexpr int64_t quadword_num_per_group =
token_num_per_group * quadword_num;
int32_t* key_cache_start_ptr =
reinterpret_cast<int32_t*>(key_cache +
block_idx * num_blocks_stride +
head_idx * cache_head_num_stride) +
group_idx * quadword_num_per_group + group_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; j < quadword_num;
i += token_num_per_group, ++j) {
key_cache_start_ptr[i] = key_start_quadword_ptr[j];
}
}
{
// Write Value
// Different from Key, block_size dimension is packed rather than
// head_size dimension block_size dimension is packed as quand-words;
constexpr int64_t token_num_per_sub_group = 4 / sizeof(scalar_t);
const int64_t token_num_per_group = block_size;
constexpr int64_t head_elems_per_group = amx_b_tile_n_size;
const int64_t group_size = token_num_per_group * head_elems_per_group;
// For now suppose head_dim is divisible by amx_b_tile_n_size
static_assert(head_dim % head_elems_per_group == 0);
constexpr int64_t group_num = head_dim / head_elems_per_group;
const int64_t sub_group_idx = block_offset / token_num_per_sub_group;
const int64_t sub_group_offset =
block_offset % token_num_per_sub_group;
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride +
sub_group_idx * token_num_per_sub_group * amx_b_tile_n_size +
sub_group_offset;
for (int64_t i = 0; i < group_num; ++i) {
#pragma GCC unroll head_elems_per_group
for (int64_t j = 0, k = 0; j < head_elems_per_group;
++j, k += token_num_per_sub_group) {
value_cache_start_ptr[k] = value_start_ptr[j];
}
value_start_ptr += head_elems_per_group;
value_cache_start_ptr += group_size;
}
}
}
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t current_q_head_num_;
};
} // namespace cpu_attention
#endif

2000
csrc/cpu/cpu_attn_impl.hpp Normal file

File diff suppressed because it is too large Load Diff

113
csrc/cpu/cpu_attn_macros.h Normal file
View File

@@ -0,0 +1,113 @@
#ifndef CPU_ATTN_MACROS_H
#define CPU_ATTN_MACROS_H
// x86_64
#ifdef __x86_64__
#define FAST_SPINNING _mm_pause();
#ifdef __AVX512F__
#define DEFINE_FAST_EXP \
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \
const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \
const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \
const __m512 vec_exp_log2ef = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \
const __m512 vec_half = _mm512_set1_ps(0.5f); \
const __m512 vec_one = _mm512_set1_ps(1.f); \
const __m512 vec_zero = _mm512_set1_ps(0.f); \
const __m512 vec_two = _mm512_set1_ps(2.f); \
const __m512 vec_ln2f = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \
const __m512 vec_ln_flt_min = \
_mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \
const __m512 vec_ln_flt_max = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \
auto vec_fx_i = _mm512_cvt_roundps_epi32( \
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \
vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \
auto vec_res = \
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \
vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \
vec_two_pow_n, vec_zero); \
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \
vec_res = _mm512_mul_ps(vec_res, vec_two); \
vec_op::FP32Vec16 res(vec_res); \
return res; \
};
#endif
#endif
#ifdef __aarch64__
// Implementation copied from Arm Optimized Routines (expf AdvSIMD)
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
#include <limits>
#define DEFINE_FAST_EXP \
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
const float ln2_hi = 0x1.62e4p-1f; \
const float ln2_lo = 0x1.7f7d1cp-20f; \
const float c0 = 0x1.0e4020p-7f; \
const float c2 = 0x1.555e66p-3f; \
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \
const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \
const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \
const float32x4_t inf = \
vdupq_n_f32(std::numeric_limits<float>::infinity()); \
const float32x4_t zero = vdupq_n_f32(0.0f); \
auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \
r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \
float32x4_t r2 = vmulq_f32(r, r); \
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \
q = vfmaq_f32(q, p, r2); \
p = vmulq_f32(c4, r); \
float32x4_t poly = vfmaq_f32(p, q, r2); \
poly = vfmaq_f32(scale, poly, scale); \
const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \
const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
result.val[1] = neon_expf(vec.reg.val[1]); \
result.val[2] = neon_expf(vec.reg.val[2]); \
result.val[3] = neon_expf(vec.reg.val[3]); \
return vec_op::FP32Vec16(result); \
};
#endif // __aarch64__
#endif

386
csrc/cpu/cpu_attn_neon.hpp Normal file
View File

@@ -0,0 +1,386 @@
#ifndef CPU_ATTN_NEON_HPP
#define CPU_ATTN_NEON_HPP
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <type_traits>
namespace cpu_attention {
namespace {
#define BLOCK_SIZE_ALIGNMENT 32
#define HEAD_SIZE_ALIGNMENT 32
#define MAX_Q_HEAD_NUM_PER_ITER 16
// These do not use vectorized class for loading / converting
// because csrc/cpu/cpu_types_arm.hpp does not have fallback options
// for vec_op::BF16Vec* / vec_op::BF16Vec* on Arm HW that
// doesn't support BF16.
// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* for consistency.
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0,
float32x4_t& b1);
template <>
FORCE_INLINE void load_row8_B_as_f32<float>(const float* p, float32x4_t& b0,
float32x4_t& b1) {
b0 = vld1q_f32(p + 0);
b1 = vld1q_f32(p + 4);
}
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::Half>(const c10::Half* p,
float32x4_t& b0,
float32x4_t& b1) {
const float16_t* h = reinterpret_cast<const float16_t*>(p);
float16x8_t v = vld1q_f16(h);
b0 = vcvt_f32_f16(vget_low_f16(v));
b1 = vcvt_f32_f16(vget_high_f16(v));
}
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
float32x4_t& b0,
float32x4_t& b1) {
const uint16_t* u = reinterpret_cast<const uint16_t*>(p);
#ifdef ARM_BF16_SUPPORT
uint16x8_t u0 = vld1q_u16(u);
bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0);
b0 = vcvtq_low_f32_bf16(bf0);
b1 = vcvtq_high_f32_bf16(bf0);
#else
uint16x8_t x0 = vld1q_u16(u);
uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16);
uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16);
b0 = vreinterpretq_f32_u32(lo);
b1 = vreinterpretq_f32_u32(hi);
#endif
}
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #FMLAs = (K // 4) * (4 * 2 * M)
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4(
const float* __restrict A, // [M x K],
const kv_cache_t* __restrict B, // [K x 8],
float* __restrict C, // [M x 8],
int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
// kernel supports max M of 8, as it'd spill for larger M
static_assert(1 <= M && M <= 8, "M must be in [1,8]");
// helpers for per-M codegen
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
#define IF_M(i) if constexpr (M > (i))
// A row base pointers
#define DECL_A(i) const float* a##i = A + (i) * lda;
ROWS_APPLY(DECL_A)
#undef DECL_A
// declare 2 accumulators per row of M
#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1;
ROWS_APPLY(DECL_ACC)
#undef DECL_ACC
// initialize accumulators
#define INIT_ACC(i) \
IF_M(i) { \
if (accumulate) { \
acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \
acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \
} else { \
acc##i##_0 = vdupq_n_f32(0.f); \
acc##i##_1 = vdupq_n_f32(0.f); \
} \
}
ROWS_APPLY(INIT_ACC)
#undef INIT_ACC
int32_t k = 0;
// K unrolled by 4
for (; k + 3 < K; k += 4) {
// load A[k..k+3] for each active row (M)
#define LOAD_A4(i) \
float32x4_t a##i##v; \
IF_M(i) a##i##v = vld1q_f32(a##i + k);
ROWS_APPLY(LOAD_A4)
#undef LOAD_A4
// helper: FMA lane L from aiv
#define FMAS_LANE(i, aiv, L) \
IF_M(i) { \
acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \
acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \
}
// k + 0
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb, b0, b1);
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
ROWS_APPLY(STEP_K0)
#undef STEP_K0
}
// k + 1
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb, b0, b1);
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
ROWS_APPLY(STEP_K1)
#undef STEP_K1
}
// k + 2
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb, b0, b1);
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
ROWS_APPLY(STEP_K2)
#undef STEP_K2
}
// k + 3
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb, b0, b1);
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
ROWS_APPLY(STEP_K3)
#undef STEP_K3
}
#undef FMAS_LANE
}
// K tail
for (; k < K; ++k) {
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb, b0, b1);
#define TAIL_ROW(i) \
IF_M(i) { \
float32x4_t ai = vdupq_n_f32(*(a##i + k)); \
acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \
acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \
}
ROWS_APPLY(TAIL_ROW)
#undef TAIL_ROW
}
// store accumulators to C
#define STORE_ROW(i) \
IF_M(i) { \
vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \
vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \
}
ROWS_APPLY(STORE_ROW)
#undef STORE_ROW
#undef ROWS_APPLY
#undef IF_M
}
template <int32_t N, typename kv_cache_t>
FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A,
const kv_cache_t* __restrict B,
float* __restrict C, int32_t M,
int32_t K, int64_t lda,
int64_t ldb, int64_t ldc,
bool accumulate) {
// micro kernel is Mx8
static_assert(N % 8 == 0, "N must be a multiple of 8");
for (int32_t m = 0; m < M;) {
int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
const float* Ab = A + m * lda;
float* Cb = C + m * ldc;
for (int32_t n = 0; n < N; n += 8) {
const kv_cache_t* Bn = B + n;
float* Cn = Cb + n;
switch (mb) {
case 8:
gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
case 4:
gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
case 2:
gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
default:
gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
}
}
// no tail loop for N as it's guaranteed to be a multiple of 8
m += mb;
}
}
template <typename kv_cache_t>
class TileGemmNeonFMLA {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
if constexpr (phase == AttentionGemmPhase::QK) {
gemm_macro_neon_fmla_Mx8_Ku4<BLOCK_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
} else {
gemm_macro_neon_fmla_Mx8_Ku4<HEAD_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
accum_c);
}
}
};
} // namespace
// this is similar to "ISA::VEC" at the moment
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::NEON;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
static_assert(HeadDim % HeadDimAlignment == 0);
// the gemm micro kernel is Mx8
static_assert(HeadDimAlignment % 8 == 0);
static_assert(BlockSizeAlignment % 8 == 0);
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemmNeonFMLA<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
// Copy q to q_buffer and cast it to fp32
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
float* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
static_assert(head_dim % 16 == 0);
constexpr int32_t unroll_size = head_dim / 16;
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
scalar_t* __restrict__ curr_q =
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
float* __restrict__ curr_q_buffer =
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
q_head_idx * head_dim;
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
load_vec_t vec(curr_q);
vec_op::FP32Vec16 fp32_vec(vec);
fp32_vec = fp32_vec * scale_vec;
fp32_vec.save(curr_q_buffer);
curr_q += 16;
curr_q_buffer += 16;
});
}
}
}
// reshape K as column-major and V as row-major
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key
const scalar_t* key_start_ptr = key +
token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_cache_start_ptr =
key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_cache_start_ptr[j] = key_start_ptr[i];
}
}
{
// Write Value
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset * head_dim;
std::memcpy(value_cache_start_ptr, value_start_ptr,
sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#endif // #ifndef CPU_ATTN_NEON_HPP

248
csrc/cpu/cpu_attn_vec.hpp Normal file
View File

@@ -0,0 +1,248 @@
#ifndef CPU_ATTN_VEC_HPP
#define CPU_ATTN_VEC_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
template <typename kv_cache_t>
class TileGemm82 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M <= 8);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
float* __restrict__ curr_c_0 = c_tile;
float* __restrict__ curr_c_1 = c_tile + 16;
vec_op::FP32Vec16 c_regs[M * 2];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
float* __restrict__ curr_m_c_1 = curr_c_1;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
// update
curr_m_c_0 += ldc;
curr_m_c_1 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
load_vec_t b_1_reg(curr_b_1);
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
curr_b_1 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2].save(curr_c_0);
c_regs[i * 2 + 1].save(curr_c_1);
// update
curr_c_0 += ldc;
curr_c_1 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
32; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
32; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 8;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm82<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
// Copy q to q_buffer and cast it to fp32
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
float* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
static_assert(head_dim % 16 == 0);
constexpr int32_t unroll_size = head_dim / 16;
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
scalar_t* __restrict__ curr_q =
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
float* __restrict__ curr_q_buffer =
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
q_head_idx * head_dim;
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
load_vec_t vec(curr_q);
vec_op::FP32Vec16 fp32_vec(vec);
fp32_vec = fp32_vec * scale_vec;
fp32_vec.save(curr_q_buffer);
curr_q += 16;
curr_q_buffer += 16;
});
}
}
}
// reshape K as column-major and V as row-major
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key as column-major
const scalar_t* key_start_ptr = key +
token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_cache_start_ptr =
key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_cache_start_ptr[j] = key_start_ptr[i];
}
}
{
// Write Value as row-major
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset * head_dim;
std::memcpy(value_cache_start_ptr, value_start_ptr,
sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#endif

171
csrc/cpu/cpu_attn_vec16.hpp Normal file
View File

@@ -0,0 +1,171 @@
#ifndef CPU_ATTN_VEC16_HPP
#define CPU_ATTN_VEC16_HPP
#include "cpu_attn_vec.hpp"
namespace cpu_attention {
namespace {
// 16-1-16 pattern, 16 regs for A, 1 regs for B, 16 regs for C, [16, K] @ [k,
// 16]
template <typename kv_cache_t>
class TileGemm161 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 9:
case 10:
case 11:
case 12:
gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 13:
case 14:
case 15:
case 16:
gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M <= 16);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
float* __restrict__ curr_c_0 = c_tile;
vec_op::FP32Vec16 c_regs[M];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
// update
curr_m_c_0 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i].save(curr_c_0);
// update
curr_c_0 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
16; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
16; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 16;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC16;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm161<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
};
} // namespace cpu_attention
#endif

25
csrc/cpu/cpu_types.hpp Normal file
View File

@@ -0,0 +1,25 @@
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP
#if defined(__x86_64__)
// x86 implementation
#include "cpu_types_x86.hpp"
#elif defined(__POWER9_VECTOR__)
// ppc implementation
#include "cpu_types_vsx.hpp"
#elif defined(__s390x__)
// s390 implementation
#include "cpu_types_vxe.hpp"
#elif defined(__aarch64__)
// arm implementation
#include "cpu_types_arm.hpp"
#else
#warning "unsupported vLLM cpu implementation, vLLM will compile with scalar"
#include "cpu_types_scalar.hpp"
#endif
#ifdef _OPENMP
#include <omp.h>
#endif
#endif

856
csrc/cpu/cpu_types_arm.hpp Normal file
View File

@@ -0,0 +1,856 @@
#include <arm_neon.h>
#include <torch/all.h>
#include <cmath>
#if defined(__APPLE__)
#include "omp.h"
#endif
namespace vec_op {
#ifdef ARM_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
// Number of elements in single ASIMD vector of given Datatype
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
};
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
struct FP32Vec8;
struct FP32Vec16;
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
float16x8_t reg;
explicit FP16Vec8(const void* ptr)
: reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
float16x8x2_t reg;
explicit FP16Vec16(const void* ptr) {
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
}
explicit FP16Vec16(const FP32Vec16& vec);
void save(void* ptr) const {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
}
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
if (full_blocks > 0) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
if (full_blocks > 1) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
}
}
// Note: below is the unrolled version of the following code:
//
// for (int i = 0; i < remainder; ++i) {
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
// vgetq_lane_f16(temp, i);
// }
//
// For macOS build (Clang), the arm/neon intrinsics function
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
// time.
if (remainder > 0) {
float16x8_t temp = reg.val[full_blocks];
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
switch (remainder) {
case 1:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
break;
case 2:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
break;
case 3:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
break;
case 4:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
break;
case 5:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
break;
case 6:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
break;
case 7:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
break;
default:
break;
}
}
}
};
#ifdef ARM_BF16_SUPPORT
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
bfloat16x8_t reg;
explicit BF16Vec8(const void* ptr)
: reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};
explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
explicit BF16Vec8(const FP32Vec8&);
explicit BF16Vec8(float32x4x2_t v)
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
bfloat16x8x2_t reg;
explicit BF16Vec16(const void* ptr)
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
explicit BF16Vec16(const FP32Vec16&);
explicit BF16Vec16(float32x4x4_t v)
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_bf16(
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
bfloat16x8_t temp = reg.val[full_blocks];
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
}
};
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
bfloat16x8x4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};
explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_bf16(
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
bfloat16x8_t temp = reg.val[full_blocks];
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
base[0] = vgetq_lane_bf16(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
}
};
};
#endif
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
float32x4_t reg;
float values[VEC_ELEM_NUM];
};
float32x4_t reg;
explicit FP32Vec4(float v) : reg(vdupq_n_f32(v)) {};
explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};
explicit FP32Vec4(float32x4_t data) : reg(data) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
float32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
float32x4x2_t reg;
explicit FP32Vec8(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v)}) {};
explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
explicit FP32Vec8(const float* ptr)
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) {
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
};
explicit FP32Vec8(float16x8_t v)
: reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec8(bfloat16x8_t v)
: reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
explicit FP32Vec8(const BF16Vec8& v)
: reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
#endif
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&answer, &ar](int i) { answer += ar.values[i]; });
return answer;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
float32x2_t exp_vec0 = {expf(ar.values[0]), expf(ar.values[1])};
float32x2_t exp_vec1 = {expf(ar.values[2]), expf(ar.values[3])};
float32x2_t exp_vec2 = {expf(ar.values[4]), expf(ar.values[5])};
float32x2_t exp_vec3 = {expf(ar.values[6]), expf(ar.values[7])};
float32x4_t result0 = vcombine_f32(exp_vec0, exp_vec1);
float32x4_t result1 = vcombine_f32(exp_vec2, exp_vec3);
float32x4x2_t result;
result.val[0] = result0;
result.val[1] = result1;
return FP32Vec8(result);
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
float32x2_t tanh_vec0 = {tanhf(ar.values[0]), tanhf(ar.values[1])};
float32x2_t tanh_vec1 = {tanhf(ar.values[2]), tanhf(ar.values[3])};
float32x2_t tanh_vec2 = {tanhf(ar.values[4]), tanhf(ar.values[5])};
float32x2_t tanh_vec3 = {tanhf(ar.values[6]), tanhf(ar.values[7])};
float32x4_t result0 = vcombine_f32(tanh_vec0, tanh_vec1);
float32x4_t result1 = vcombine_f32(tanh_vec2, tanh_vec3);
float32x4x2_t result;
result.val[0] = result0;
result.val[1] = result1;
return FP32Vec8(result);
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
static_cast<float32_t>(erf(ar.values[1]))};
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
static_cast<float32_t>(erf(ar.values[3]))};
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
static_cast<float32_t>(erf(ar.values[5]))};
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
static_cast<float32_t>(erf(ar.values[7]))};
float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
float32x4x2_t result;
result.val[0] = result0;
result.val[1] = result1;
return FP32Vec8(result);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
vmulq_f32(reg.val[1], b.reg.val[1])}));
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
vaddq_f32(reg.val[1], b.reg.val[1])}));
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
vsubq_f32(reg.val[1], b.reg.val[1])}));
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
vdivq_f32(reg.val[1], b.reg.val[1])}));
}
void save(float* ptr) const {
vst1q_f32(ptr, reg.val[0]);
vst1q_f32(ptr + 4, reg.val[1]);
}
};
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
int32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
int32x4x4_t reg;
explicit INT32Vec16(const void* ptr) {
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
}
void save(int32_t* ptr) const {
vst1q_s32(ptr, reg.val[0]);
vst1q_s32(ptr + 4, reg.val[1]);
vst1q_s32(ptr + 8, reg.val[2]);
vst1q_s32(ptr + 12, reg.val[3]);
};
void save(int32_t* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_s32(
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
int32x4_t temp = reg.val[full_blocks];
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
}
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
float32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
float32x4x4_t reg;
explicit FP32Vec16(float v)
: reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
explicit FP32Vec16()
: reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
vmovq_n_f32(0.0)}) {}
explicit FP32Vec16(const float* ptr)
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
vld1q_f32(ptr + 12)}) {}
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(bfloat16x8x2_t v)
: reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
#endif
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
};
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(const BF16Vec16& v)
: reg({vcvtq_low_f32_bf16(v.reg.val[0]),
vcvtq_high_f32_bf16(v.reg.val[0]),
vcvtq_low_f32_bf16(v.reg.val[1]),
vcvtq_high_f32_bf16(v.reg.val[1])}) {};
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
#endif
explicit FP32Vec16(const FP16Vec16& v) {
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
};
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
};
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
vaddq_f32(reg.val[1], b.reg.val[1]),
vaddq_f32(reg.val[2], b.reg.val[2]),
vaddq_f32(reg.val[3], b.reg.val[3])}));
};
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
vmulq_f32(reg.val[1], b.reg.val[1]),
vmulq_f32(reg.val[2], b.reg.val[2]),
vmulq_f32(reg.val[3], b.reg.val[3])}));
};
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
vsubq_f32(reg.val[1], b.reg.val[1]),
vsubq_f32(reg.val[2], b.reg.val[2]),
vsubq_f32(reg.val[3], b.reg.val[3])}));
};
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
vdivq_f32(reg.val[1], b.reg.val[1]),
vdivq_f32(reg.val[2], b.reg.val[2]),
vdivq_f32(reg.val[3], b.reg.val[3])}));
};
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(float32x4x4_t(
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
};
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
vmaxq_f32(b.reg.val[1], reg.val[1]),
vmaxq_f32(b.reg.val[2], reg.val[2]),
vmaxq_f32(b.reg.val[3], reg.val[3])}));
};
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
float32x4x4_t temp;
for (int i = 0; i < full_blocks; i++)
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
if (remainder > 0) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
vgetq_lane_f32(b.reg.val[full_blocks], 0));
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
}
if (remainder > 1) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
vgetq_lane_f32(b.reg.val[full_blocks], 1));
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
}
if (remainder > 2) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
vgetq_lane_f32(b.reg.val[full_blocks], 2));
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
}
return FP32Vec16(temp);
};
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({
vminq_f32(b.reg.val[0], reg.val[0]),
vminq_f32(b.reg.val[1], reg.val[1]),
vminq_f32(b.reg.val[2], reg.val[2]),
vminq_f32(b.reg.val[3], reg.val[3]),
}));
};
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
float32x4x4_t temp;
for (int i = 0; i < full_blocks; i++)
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
if (remainder > 0) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
vgetq_lane_f32(b.reg.val[full_blocks], 0));
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
}
if (remainder > 1) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
vgetq_lane_f32(b.reg.val[full_blocks], 1));
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
}
if (remainder > 2) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
vgetq_lane_f32(b.reg.val[full_blocks], 2));
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
}
return FP32Vec16(temp);
};
FP32Vec16 abs() const {
return FP32Vec16(
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&answer, &ar](int i) { answer += ar.values[i]; });
return answer;
};
float reduce_max() const {
AliasReg ar;
ar.reg = reg;
float max_v = std::numeric_limits<float>::lowest();
unroll_loop<int, VEC_ELEM_NUM>(
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
return max_v;
}
float reduce_min() const {
AliasReg ar;
ar.reg = reg;
float min_v = std::numeric_limits<float>::max();
unroll_loop<int, VEC_ELEM_NUM>(
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
return min_v;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float answer = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&answer, &start, ar](int i) { answer += ar.values[start + i]; });
return answer;
};
void save(float* ptr) const {
vst1q_f32(ptr, reg.val[0]);
vst1q_f32(ptr + 4, reg.val[1]);
vst1q_f32(ptr + 8, reg.val[2]);
vst1q_f32(ptr + 12, reg.val[3]);
};
void save(float* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_f32(
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
float32x4_t temp = reg.val[full_blocks];
float* base = reinterpret_cast<float32_t*>(ptr) +
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
}
}
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
int8x16_t reg;
int8_t values[VEC_ELEM_NUM];
};
int8x16_t reg;
explicit INT8Vec16(const FP32Vec16& vec) {
// Convert each 128-bit float32 vector to int32
int32x4_t part0 =
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
int32x4_t part1 =
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
int32x4_t part2 =
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
int32x4_t part3 =
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
// Narrow each 32-bit vector to 8 bits and combine
int8x8_t lower =
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
int8x8_t upper =
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
}
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
void save(int8_t* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
for (int i = 0; i < full_blocks; i++)
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
if (remainder > 0) {
int8x16_t temp = reg;
int8_t* base =
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
}
};
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
#ifdef ARM_BF16_SUPPORT
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
#endif
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*reinterpret_cast<__fp16*>(ptr) = v;
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
reg.val[0] = vcombine_f16(low_0, high_0);
reg.val[1] = vcombine_f16(low_1, high_1);
};
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
reg = vcombine_f16(lower_half, upper_half);
};
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]);
};
#ifdef ARM_BF16_SUPPORT
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
float32x4_t a1_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[1]));
float32x4_t b0_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[0]));
float32x4_t b0_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[0]));
float32x4_t b1_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[1]));
float32x4_t b1_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[1]));
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a0_low, b0_low);
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a0_high, b0_high);
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a1_low, b1_low);
acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a1_high, b1_high);
};
#endif
#ifdef ARM_BF16_SUPPORT
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
};
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
v.reg.val[3])}) {};
#endif
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
#ifdef ARM_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
};
#endif
}; // namespace vec_op

View File

@@ -0,0 +1,465 @@
#include <cmath>
#include <cstdint>
#include <cstring>
#include <torch/all.h>
#include "float_convert.hpp"
namespace vec_op {
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
typedef struct f16x8_t {
uint16_t val[8];
} f16x8_t;
typedef struct f16x16_t {
uint16_t val[16];
} f16x16_t;
typedef struct f16x32_t {
uint16_t val[32];
} f16x32_t;
typedef struct f32x4_t {
float val[4];
} f32x4_t;
typedef struct f32x8_t {
float val[8];
} f32x8_t;
typedef struct f32x16_t {
float val[16];
} f32x16_t;
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
};
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T> > >
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
f16x8_t reg;
explicit FP16Vec8(const void* ptr)
: reg(*reinterpret_cast<const f16x8_t*>(ptr)) {};
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<f16x8_t*>(ptr) = reg; }
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
f16x16_t reg;
explicit FP16Vec16(const void* ptr)
: reg(*reinterpret_cast<const f16x16_t*>(ptr)) {};
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
f16x8_t reg;
explicit BF16Vec8(const void* ptr)
: reg(*reinterpret_cast<const f16x8_t*>(ptr)) {};
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<f16x8_t*>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
f16x16_t reg;
explicit BF16Vec16(const void* ptr)
: reg(*reinterpret_cast<const f16x16_t*>(ptr)) {};
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
f16x32_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const f16x32_t*>(ptr)) {};
explicit BF16Vec32(f16x32_t data) : reg(data) {};
explicit BF16Vec32(BF16Vec8& vec8_data) {
unroll_loop<int, VEC_ELEM_NUM>([&vec8_data, this](int i) {
reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
});
}
void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
f32x4_t reg;
explicit FP32Vec4(float v) {
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec4() {
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec4(const float* ptr)
: reg(*reinterpret_cast<const f32x4_t*>(ptr)) {};
explicit FP32Vec4(f32x4_t data) : reg(data) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
f32x8_t reg;
explicit FP32Vec8(float v) {
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec8() {
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec8(const float* ptr)
: reg(*reinterpret_cast<const f32x8_t*>(ptr)) {};
explicit FP32Vec8(f32x8_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
FP32Vec8(const BF16Vec8& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
float reduce_sum() const {
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result += reg.val[i]; });
return result;
}
FP32Vec8 exp() const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = expf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 tanh() const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 er() const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = erf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator+(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator-(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator/(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
return FP32Vec8(ret);
}
void save(void* ptr) const { *reinterpret_cast<f32x8_t*>(ptr) = reg; }
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
f32x16_t reg;
explicit FP32Vec16(float v) {
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec16() {
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec16(const float* ptr)
: reg(*reinterpret_cast<const f32x16_t*>(ptr)) {};
explicit FP32Vec16(f32x16_t data) : reg(data) {};
FP32Vec16(const FP32Vec4& data) {
unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
});
}
FP32Vec16(const FP32Vec8& data) {
unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
});
}
FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const BF16Vec16& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16 operator*(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 operator+(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 operator-(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 operator/(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 max(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
ret.val[i] = std::max(reg.val[i], b.reg.val[i]);
});
return FP32Vec16(ret);
}
FP32Vec16 min(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
ret.val[i] = std::min(reg.val[i], b.reg.val[i]);
});
return FP32Vec16(ret);
}
FP32Vec16 abs() const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); });
return FP32Vec16(ret);
}
float reduce_sum() const {
float result = 0.0f;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result += reg.val[i]; });
return result;
}
float reduce_max() const {
float result = std::numeric_limits<float>::lowest();
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result = std::max(reg.val[i], result); });
return result;
}
float reduce_min() const {
float result = std::numeric_limits<float>::max();
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result = std::min(reg.val[i], result); });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
float sum = 0.0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&sum, &start, this](int i) { sum += reg.val[start + i]; });
return sum;
}
void save(void* ptr) const { *reinterpret_cast<f32x16_t*>(ptr) = reg; }
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
/*
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
c10::Half __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::Half *>(&v);
*ptr = *(v_ptr + 1);
}
*/
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
uint16_t fp16 = float_to_fp16(v);
*reinterpret_cast<uint16_t*>(ptr) = fp16;
}
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
unroll_loop<int, FP16Vec16::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
unroll_loop<int, FP16Vec8::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc = acc + a * b;
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
unroll_loop<int, BF16Vec8::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
unroll_loop<int, BF16Vec16::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
}; // namespace vec_op

788
csrc/cpu/cpu_types_vsx.hpp Normal file
View File

@@ -0,0 +1,788 @@
#ifndef CPU_TYPES_VSX_HPP
#define CPU_TYPES_VSX_HPP
#include <altivec.h>
#include <cmath>
#include <algorithm>
#include <torch/all.h>
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
typedef struct ss16x8x2_t {
__vector signed short val[2];
} ss16x8x2_t;
typedef struct ss16x8x4_t {
__vector signed short val[4];
} ss16x8x4_t;
typedef struct f32x4x2_t {
__vector float val[2];
} f32x4x2_t;
typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
typedef struct i32x4x4_t {
__vector int32_t val[4];
} i32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit BF16Vec8(const void* ptr)
: reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
*reinterpret_cast<__vector signed short*>(ptr) = reg;
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit BF16Vec16(const void* ptr) {
// Load 256 bits in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
void save(void* ptr, const int elem_num) const {
const int clamped_elem = std::max(0, std::min(elem_num, 16));
// Calculate elements to store in each 128-bit part (8 elements each)
const int elements_val0 = std::min(clamped_elem, 8);
const int elements_val1 = std::max(clamped_elem - 8, 0);
// Convert elements to bytes (2 bytes per element)
const size_t bytes_val0 = elements_val0 * sizeof(signed short);
const size_t bytes_val1 = elements_val1 * sizeof(signed short);
signed short* dest = static_cast<signed short*>(ptr);
// Store the first part using vec_xst_len
if (bytes_val0 > 0) {
vec_xst_len(reg.val[0], dest, bytes_val0);
}
// Store the second part if needed
if (bytes_val1 > 0) {
vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
}
}
};
const static __vector signed short zero = vec_splats((signed short)0);
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
ss16x8x4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__vector float reg;
float values[VEC_ELEM_NUM];
};
__vector float reg;
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
explicit FP32Vec4(__vector float data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
f32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x2_t reg;
explicit FP32Vec8(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
}
explicit FP32Vec8() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
}
explicit FP32Vec8(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
}
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
}
explicit FP32Vec8(const BF16Vec8& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::exp(ar.values[0]);
ret.val[0][1] = std::exp(ar.values[1]);
ret.val[0][2] = std::exp(ar.values[2]);
ret.val[0][3] = std::exp(ar.values[3]);
ret.val[1][0] = std::exp(ar.values[4]);
ret.val[1][1] = std::exp(ar.values[5]);
ret.val[1][2] = std::exp(ar.values[6]);
ret.val[1][3] = std::exp(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 tanh() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::tanh(ar.values[0]);
ret.val[0][1] = std::tanh(ar.values[1]);
ret.val[0][2] = std::tanh(ar.values[2]);
ret.val[0][3] = std::tanh(ar.values[3]);
ret.val[1][0] = std::tanh(ar.values[4]);
ret.val[1][1] = std::tanh(ar.values[5]);
ret.val[1][2] = std::tanh(ar.values[6]);
ret.val[1][3] = std::tanh(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 er() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::erf(ar.values[0]);
ret.val[0][1] = std::erf(ar.values[1]);
ret.val[0][2] = std::erf(ar.values[2]);
ret.val[0][3] = std::erf(ar.values[3]);
ret.val[1][0] = std::erf(ar.values[4]);
ret.val[1][1] = std::erf(ar.values[5]);
ret.val[1][2] = std::erf(ar.values[6]);
ret.val[1][3] = std::erf(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
}
};
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
i32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
i32x4x4_t reg;
explicit INT32Vec16(const void* data_ptr) {
reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[1] =
vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[2] =
vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[3] =
vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
}
void save(int32_t* ptr) const {
vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
}
void save(int32_t* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
f32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x4_t reg;
explicit FP32Vec16(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
reg.val[2] = vec_splats(v);
reg.val[3] = vec_splats(v);
}
explicit FP32Vec16() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
reg.val[2] = vec_splats(0.0f);
reg.val[3] = vec_splats(0.0f);
}
explicit FP32Vec16(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
reg.val[2] = vec_xl(32, ptr);
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[2];
reg.val[3] = data.reg.val[3];
}
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
}
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const BF16Vec16& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vec_ctf(v.reg.val[0], 0);
reg.val[1] = vec_ctf(v.reg.val[1], 0);
reg.val[2] = vec_ctf(v.reg.val[2], 0);
reg.val[3] = vec_ctf(v.reg.val[3], 0);
}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
vec_mul(reg.val[2], b.reg.val[2]),
vec_mul(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
vec_add(reg.val[1], b.reg.val[1]),
vec_add(reg.val[2], b.reg.val[2]),
vec_add(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
vec_sub(reg.val[1], b.reg.val[1]),
vec_sub(reg.val[2], b.reg.val[2]),
vec_sub(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
vec_div(reg.val[1], b.reg.val[1]),
vec_div(reg.val[2], b.reg.val[2]),
vec_div(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(f32x4x4_t(
{vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
// Create a vector of element indices for each chunk
__vector unsigned int indices = {0, 1, 2, 3};
__vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
// Compute masks for each chunk
__vector unsigned int chunk_offset0 = {0, 0, 0,
0}; // Chunk 0: Elements 0-3
__vector unsigned int chunk_offset1 = {4, 4, 4,
4}; // Chunk 1: Elements 4-7
__vector unsigned int chunk_offset2 = {8, 8, 8,
8}; // Chunk 2: Elements 8-11
__vector unsigned int chunk_offset3 = {12, 12, 12,
12}; // Chunk 3: Elements 12-15
// Compute masks for each chunk
__vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
__vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
__vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
__vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
// Apply masks to compute the result for each chunk
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_max(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_max(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_max(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_max(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
vec_min(reg.val[1], b.reg.val[1]),
vec_min(reg.val[2], b.reg.val[2]),
vec_min(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
vector unsigned int indices = {0, 1, 2, 3};
vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
vector unsigned int chunk_offset0 = {0, 0, 0, 0};
vector unsigned int chunk_offset1 = {4, 4, 4, 4};
vector unsigned int chunk_offset2 = {8, 8, 8, 8};
vector unsigned int chunk_offset3 = {12, 12, 12, 12};
vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_min(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_min(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_min(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_min(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 abs() const {
return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
}
float reduce_max() {
__vector float max01 = vec_max(reg.val[0], reg.val[1]);
__vector float max23 = vec_max(reg.val[2], reg.val[3]);
__vector float max_all = vec_max(max01, max23);
__vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
temp = vec_max(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_min() {
__vector float min01 = vec_min(reg.val[0], reg.val[1]);
__vector float min23 = vec_min(reg.val[2], reg.val[3]);
__vector float min_all = vec_min(min01, min23);
__vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
temp = vec_min(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float result = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&result, &start, ar](int i) { result += ar.values[start + i]; });
return result;
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
void save(float* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(float));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(float));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(float));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(float));
vec_xst_len(reg.val[0], ptr, bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16
union AliasReg {
__vector signed char reg;
int8_t values[VEC_NUM_ELEM];
};
__vector signed char reg;
explicit INT8Vec16(const FP32Vec16& vec) {
__vector signed int ret[4];
ret[0] = vec_cts(vec.reg.val[0], 0);
ret[1] = vec_cts(vec.reg.val[1], 0);
ret[2] = vec_cts(vec.reg.val[2], 0);
ret[3] = vec_cts(vec.reg.val[3], 0);
__vector signed short packed1 = vec_packs(ret[0], ret[1]);
__vector signed short packed2 = vec_packs(ret[2], ret[3]);
reg = vec_packs(packed1, packed2);
}
void save(void* ptr) const {
*reinterpret_cast<__vector signed char*>(ptr) = reg;
}
void save(signed char* ptr, const int elem_num) {
vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
}
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc = acc + a * b;
}
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
16, 17, 20, 21, 24, 25, 28, 29};
#ifndef _ARCH_PWR10
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
#endif
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
#ifdef _ARCH_PWR10
__vector signed short ret[2];
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[0]);
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[1]);
reg = vec_perm(ret[0], ret[1], omask);
#elif defined(_ARCH_PWR9)
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int lsb0 = vec_sr(inp0, sh16);
__vector unsigned int lsb1 = vec_sr(inp1, sh16);
lsb0 = vec_and(lsb0, one);
lsb1 = vec_and(lsb1, one);
__vector unsigned int rnd0 = vec_add(lsb0, bias);
__vector unsigned int rnd1 = vec_add(lsb1, bias);
inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1);
__vector __bool int sel0 =
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 =
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp0 = vec_sr(inp0, sh16);
inp1 = vec_sr(inp1, sh16);
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
#endif
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
#ifdef _ARCH_PWR10
__vector signed short ret[4];
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[0]);
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[1]);
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[2]);
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[3]);
reg.val[0] = vec_perm(ret[0], ret[1], omask);
reg.val[1] = vec_perm(ret[2], ret[3], omask);
#elif defined(_ARCH_PWR9)
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
__vector unsigned int lsb0 = vec_sr(inp0, sh16);
__vector unsigned int lsb1 = vec_sr(inp1, sh16);
__vector unsigned int lsb2 = vec_sr(inp2, sh16);
__vector unsigned int lsb3 = vec_sr(inp3, sh16);
lsb0 = vec_and(lsb0, one);
lsb1 = vec_and(lsb1, one);
lsb2 = vec_and(lsb2, one);
lsb3 = vec_and(lsb3, one);
__vector unsigned int rnd0 = vec_add(lsb0, bias);
__vector unsigned int rnd1 = vec_add(lsb1, bias);
__vector unsigned int rnd2 = vec_add(lsb2, bias);
__vector unsigned int rnd3 = vec_add(lsb3, bias);
inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1);
inp2 = vec_add(inp2, rnd2);
inp3 = vec_add(inp3, rnd3);
__vector __bool int sel0 =
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 =
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
__vector __bool int sel2 =
vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
__vector __bool int sel3 =
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp2 = vec_sel(inp2, nan, sel2);
inp3 = vec_sel(inp3, nan, sel3);
inp0 = vec_sr(inp0, sh16);
inp1 = vec_sr(inp1, sh16);
inp2 = vec_sr(inp2, sh16);
inp3 = vec_sr(inp3, sh16);
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
#endif
}
inline void prefetch(const void* addr) {
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
}
}; // namespace vec_op
#endif

954
csrc/cpu/cpu_types_vxe.hpp Normal file
View File

@@ -0,0 +1,954 @@
#ifndef CPU_TYPES_VXE_HPP
#define CPU_TYPES_VXE_HPP
#include <vecintrin.h>
#include <cmath>
#include <limits>
#include <torch/all.h>
namespace vec_op {
#define vec_neg(a) (-(a))
#define vec_add(a, b) ((a) + (b))
#define vec_sub(a, b) ((a) - (b))
#define vec_mul(a, b) ((a) * (b))
#define vec_div(a, b) ((a) / (b))
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
typedef struct ss16x8x2_t {
__vector signed short val[2];
} ss16x8x2_t;
typedef struct ss16x8x4_t {
__vector signed short val[4];
} ss16x8x4_t;
typedef struct f32x4x2_t {
__vector float val[2];
} f32x4x2_t;
typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
*reinterpret_cast<__vector signed short*>(ptr) = reg;
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit BF16Vec16(const void* ptr) {
// Load 256 bits in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
};
const static __vector signed short zero = vec_splats((signed short)0);
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
ss16x8x4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__vector float reg;
float values[VEC_ELEM_NUM];
};
__vector float reg;
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
explicit FP32Vec4(__vector float data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
f32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x2_t reg;
explicit FP32Vec8(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
}
explicit FP32Vec8() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
}
explicit FP32Vec8(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
}
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
}
explicit FP32Vec8(const BF16Vec8& v) {
// On big-endian s390x, place BF16 first to get correct byte order
reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
f32x4x2_t out;
const __vector float log2e = vec_splats(1.44269504088896341f);
const __vector float one = vec_splats(1.0f);
const __vector float min_x = vec_splats(-87.3f);
const __vector float max_x = vec_splats(88.7f);
// 5th-degree minimax polynomial for 2^r (r in [0,1))
const __vector float c1 = vec_splats(0.6931471805599453f);
const __vector float c2 = vec_splats(0.240226506959101f);
const __vector float c3 = vec_splats(0.05550410866482158f);
const __vector float c4 = vec_splats(0.009618129107628477f);
const __vector float c5 = vec_splats(0.0013333558146428443f);
for (int i = 0; i < 2; i++) {
__vector float x = reg.val[i];
x = vec_max(x, min_x);
x = vec_min(x, max_x);
__vector float y = vec_mul(x, log2e);
__vector float kf = vec_floor(y);
__vector float r = vec_sub(y, kf);
__vector signed int k = vec_signed(kf);
const __vector signed int min_k = vec_splats((signed int)-126);
const __vector signed int max_k = vec_splats((signed int)127);
k = vec_min(vec_max(k, min_k), max_k);
// Build 2^k from exponent bits
__vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
__vector unsigned int bits = (__vector unsigned int)exp_int;
bits = vec_sl(bits, vec_splats((unsigned int)23));
__vector float pow2k = (__vector float)bits;
// Improved minimax polynomial
__vector float poly = vec_madd(c5, r, c4);
poly = vec_madd(poly, r, c3);
poly = vec_madd(poly, r, c2);
poly = vec_madd(poly, r, c1);
poly = vec_madd(poly, r, one);
out.val[i] = vec_mul(pow2k, poly);
}
return FP32Vec8(out);
}
FP32Vec8 tanh() const {
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
const __vector float one = vec_splats(1.0f);
const __vector float two = vec_splats(2.0f);
const __vector float zero = vec_splats(0.0f);
const __vector float sat =
vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
f32x4x2_t out;
for (int i = 0; i < 2; i++) {
__vector float x = reg.val[i];
__vector float ax = vec_abs(x);
// sign(x): +1 or -1
__vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
// saturation mask: |x| > sat
__vector __bool int saturated = vec_cmpgt(ax, sat);
// 2x
__vector float two_x = vec_mul(x, two);
// Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
f32x4x2_t tmp;
tmp.val[0] = two_x;
tmp.val[1] = two_x;
FP32Vec8 exp_2x_vec(tmp);
FP32Vec8 e2x = exp_2x_vec.exp();
__vector float e = e2x.reg.val[i];
// tanh(x) = (e - 1) / (e + 1)
__vector float num = vec_sub(e, one);
__vector float den = vec_add(e, one);
__vector float t = vec_div(num, den);
// For large |x|, clamp to sign(x)
out.val[i] = vec_sel(t, sign, saturated);
}
return FP32Vec8(out);
}
FP32Vec8 er() const {
// A&S 7.1.26 approximation:
// erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t *
// exp(-x^2)) t = 1 / (1 + p*|x|), p = 0.3275911
const __vector float one = vec_splats(1.0f);
const __vector float zero = vec_splats(0.0f);
const __vector float p = vec_splats(0.3275911f);
// Polynomial coeffs
const __vector float a1 = vec_splats(0.254829592f);
const __vector float a2 = vec_splats(-0.284496736f);
const __vector float a3 = vec_splats(1.421413741f);
const __vector float a4 = vec_splats(-1.453152027f);
const __vector float a5 = vec_splats(1.061405429f);
// Threshold where erf(x) ~ sign(x)
const __vector float sat = vec_splats(6.0f);
f32x4x2_t out;
for (int lane = 0; lane < 2; lane++) {
__vector float x = reg.val[lane];
__vector float ax = vec_abs(x);
// sign(x)
__vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
// |x| > 6 → erf(x) = ±1
__vector __bool int saturated = vec_cmpgt(ax, sat);
// t = 1 / (1 + p * |x|)
__vector float t = vec_madd(p, ax, one);
t = vec_div(one, t);
// poly = a5
__vector float poly = a5;
poly = vec_madd(poly, t, a4);
poly = vec_madd(poly, t, a3);
poly = vec_madd(poly, t, a2);
poly = vec_madd(poly, t, a1);
// full polynomial: poly = poly * t
poly = vec_mul(poly, t);
// Compute exp(-x^2)
__vector float x2 = vec_mul(x, x);
__vector float neg_x2 = vec_neg(x2);
f32x4x2_t tmp;
tmp.val[0] = neg_x2;
tmp.val[1] = neg_x2;
FP32Vec8 exp_neg_x2(tmp);
FP32Vec8 e = exp_neg_x2.exp();
__vector float ex = e.reg.val[lane];
// erf(x) = sign * (1 - poly * exp(-x^2))
__vector float term = vec_mul(poly, ex);
__vector float y = vec_sub(one, term);
y = vec_mul(y, sign);
// saturated → ±1
__vector float sat_val = vec_mul(sign, one);
out.val[lane] = vec_sel(y, sat_val, saturated);
}
return FP32Vec8(out);
}
// Elementwise sigmoid(x) = 1 / (1 + exp(-x))
FP32Vec8 sigmoid() const {
const __vector float one = vec_splats(1.0f);
f32x4x2_t neg;
for (int i = 0; i < 2; ++i) {
neg.val[i] = vec_neg(reg.val[i]);
}
FP32Vec8 neg_x(neg);
FP32Vec8 e = neg_x.exp(); // exp(-x)
f32x4x2_t denom;
for (int i = 0; i < 2; ++i) {
denom.val[i] = vec_add(one, e.reg.val[i]);
}
FP32Vec8 denom_vec(denom);
FP32Vec8 one_vec(1.0f);
return one_vec / denom_vec;
}
// Tanh-based GELU:
// gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
FP32Vec8 gelu_tanh() const {
const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
const __vector float k_0_0447 = vec_splats(0.044715f);
f32x4x2_t x2, x3, inner;
for (int i = 0; i < 2; ++i) {
__vector float x = reg.val[i];
x2.val[i] = vec_mul(x, x); // x^2
x3.val[i] = vec_mul(x2.val[i], x); // x^3
__vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
}
FP32Vec8 inner_vec(inner);
FP32Vec8 t = inner_vec.tanh(); // tanh part
FP32Vec8 one_vec(1.0f);
FP32Vec8 half_vec(0.5f);
FP32Vec8 x_vec(*this);
return x_vec * half_vec * (one_vec + t);
}
// Erf-based GELU:
// gelu(x) = 0.5 * x * (1 + erf(x / √2))
FP32Vec8 gelu_erf() const {
const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
FP32Vec8 x_vec(*this);
f32x4x2_t scaled;
for (int i = 0; i < 2; ++i) {
scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
}
FP32Vec8 x_scaled(scaled);
FP32Vec8 erf_x = x_scaled.er();
FP32Vec8 one_vec(1.0f);
FP32Vec8 half_vec(0.5f);
return x_vec * half_vec * (one_vec + erf_x);
}
// Elementwise reciprocal: 1/x (scalar per lane, for correctness)
FP32Vec8 rcp() const {
AliasReg in, out;
in.reg = reg;
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
out.values[i] = 1.0f / in.values[i];
}
return FP32Vec8(out.reg);
}
// Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
FP32Vec8 rsqrt() const {
AliasReg in, out;
in.reg = reg;
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
out.values[i] = 1.0f / std::sqrt(in.values[i]);
}
return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
f32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x4_t reg;
explicit FP32Vec16(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
reg.val[2] = vec_splats(v);
reg.val[3] = vec_splats(v);
}
explicit FP32Vec16() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
reg.val[2] = vec_splats(0.0f);
reg.val[3] = vec_splats(0.0f);
}
explicit FP32Vec16(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
reg.val[2] = vec_xl(32, ptr);
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[2];
reg.val[3] = data.reg.val[3];
}
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
}
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const BF16Vec16& v) {
// On big-endian s390x, place BF16 first to get correct byte order
reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
vec_mul(reg.val[2], b.reg.val[2]),
vec_mul(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
vec_add(reg.val[1], b.reg.val[1]),
vec_add(reg.val[2], b.reg.val[2]),
vec_add(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
vec_sub(reg.val[1], b.reg.val[1]),
vec_sub(reg.val[2], b.reg.val[2]),
vec_sub(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
vec_div(reg.val[1], b.reg.val[1]),
vec_div(reg.val[2], b.reg.val[2]),
vec_div(reg.val[3], b.reg.val[3])}));
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float result = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&result, &start, ar](int i) { result += ar.values[start + i]; });
return result;
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
float reduce_max() const {
AliasReg ar;
ar.reg = reg;
float result = ar.values[0];
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) {
if (ar.values[i] > result) result = ar.values[i];
});
return result;
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
// On s390x, FP16 (Half) is not natively supported, use FP32 vectors instead
using FP16Vec16 = FP32Vec16;
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
// value.
};
} // namespace c10
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
// intrinsics
// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_madd(a.reg, b.reg, acc.reg);
}
// FP32Vec8 FMA: acc = acc + (a * b)
FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
// FP32Vec16 FMA: acc = acc + (a * b)
FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Multiply-Subtract: acc = acc - (a * b)
FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_msub(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Negative Multiply-Add: acc = -(a * b) + acc
FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Negative Multiply-Subtract: acc = -(a * b) - acc
FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int lsb0 = inp0 >> sh16;
__vector unsigned int lsb1 = inp1 >> sh16;
lsb0 = lsb0 & one;
lsb1 = lsb1 & one;
__vector unsigned int rnd0 = lsb0 + bias;
__vector unsigned int rnd1 = lsb1 + bias;
inp0 = inp0 + rnd0;
inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp0 = inp0 >> sh16;
inp1 = inp1 >> sh16;
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
__vector unsigned int lsb0 = inp0 >> sh16;
__vector unsigned int lsb1 = inp1 >> sh16;
__vector unsigned int lsb2 = inp2 >> sh16;
__vector unsigned int lsb3 = inp3 >> sh16;
lsb0 = lsb0 & one;
lsb1 = lsb1 & one;
lsb2 = lsb2 & one;
lsb3 = lsb3 & one;
__vector unsigned int rnd0 = lsb0 + bias;
__vector unsigned int rnd1 = lsb1 + bias;
__vector unsigned int rnd2 = lsb2 + bias;
__vector unsigned int rnd3 = lsb3 + bias;
inp0 = inp0 + rnd0;
inp1 = inp1 + rnd1;
inp2 = inp2 + rnd2;
inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel2 =
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp2 = vec_sel(inp2, nan, sel2);
inp3 = vec_sel(inp3, nan, sel3);
inp0 = inp0 >> sh16;
inp1 = inp1 >> sh16;
inp2 = inp2 >> sh16;
inp3 = inp3 >> sh16;
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
// 1D softmax over `n` elements in `input`, writes result to `output`.
// Uses FP32Vec8 for main body, scalar tail handling.
// Requirement: n > 0
FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
if (n <= 0) return;
// ---------- Pass 1: find max ----------
float max_val = -std::numeric_limits<float>::infinity();
int i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 v(input + i);
FP32Vec8::AliasReg ar;
ar.reg = v.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
if (ar.values[j] > max_val) max_val = ar.values[j];
}
}
for (; i < n; ++i) {
if (input[i] > max_val) max_val = input[i];
}
// ---------- Pass 2: compute exp(x - max) and sum ----------
float sum = 0.0f;
i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
float tmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
tmp[j] = input[i + j] - max_val;
}
FP32Vec8 v(tmp);
FP32Vec8 e = v.exp();
FP32Vec8::AliasReg ar;
ar.reg = e.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
output[i + j] = ar.values[j];
sum += ar.values[j];
}
}
// Tail
for (; i < n; ++i) {
float x = input[i] - max_val;
float ex = std::exp(x); // scalar tail
output[i] = ex;
sum += ex;
}
// ---------- Pass 3: normalize ----------
float inv_sum = 1.0f / sum;
i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
float tmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
tmp[j] = output[i + j] * inv_sum;
}
FP32Vec8 v(tmp);
v.save(output + i);
}
for (; i < n; ++i) {
output[i] *= inv_sum;
}
}
// 1D RMSNorm kernel:
// input: x[0..n-1]
// weight: w[0..n-1] (gamma), may be nullptr
// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
// eps: small epsilon for numerical stability
FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
const float* weight, int n, float eps) {
if (n <= 0) return;
// ---------- Pass 1: compute sum of squares ----------
float sum_sq = 0.0f;
int i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
FP32Vec8 sq = x_vec * x_vec;
FP32Vec8::AliasReg ar;
ar.reg = sq.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
sum_sq += ar.values[j];
}
}
// Tail
for (; i < n; ++i) {
float v = input[i];
sum_sq += v * v;
}
float mean_sq = sum_sq / static_cast<float>(n);
float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
// ---------- Pass 2: scale (and apply weight if given) ----------
const float inv_rms_f = inv_rms;
i = 0;
if (weight) {
// with gamma
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
float wtmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
wtmp[j] = weight[i + j];
}
FP32Vec8 w_vec(wtmp);
FP32Vec8 scale_vec(inv_rms_f);
FP32Vec8 y = x_vec * scale_vec * w_vec;
y.save(output + i);
}
for (; i < n; ++i) {
output[i] = input[i] * inv_rms_f * weight[i];
}
} else {
// without gamma
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
FP32Vec8 scale_vec(inv_rms_f);
FP32Vec8 y = x_vec * scale_vec;
y.save(output + i);
}
for (; i < n; ++i) {
output[i] = input[i] * inv_rms_f;
}
}
}
// Prefetch data to cache for better memory access performance
FORCE_INLINE void prefetch(const void* addr) {
__builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
}
}; // namespace vec_op
#endif

794
csrc/cpu/cpu_types_x86.hpp Normal file
View File

@@ -0,0 +1,794 @@
#ifndef CPU_TYPES_X86_HPP
#define CPU_TYPES_X86_HPP
#include <immintrin.h>
#include <torch/all.h>
#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif
namespace vec_op {
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
#define CPU_KERNEL_GUARD_OUT(NAME)
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
// Function to get the timestamp using RDTSCP
FORCE_INLINE uint64_t bench_timestamp() {
unsigned int cycles_low, cycles_high;
asm volatile(
".intel_syntax noprefix\n\t"
"CPUID\n\t" // Serialize instruction stream to ensure previous
// instructions complete
"RDTSCP\n\t" // Read TSC and core ID
"mov %0, edx\n\t" // Store high 32 bits of TSC
"mov %1, eax\n\t" // Store low 32 bits of TSC
".att_syntax"
: "=r"(cycles_high), "=r"(cycles_low)::"rax", "rbx", "rcx",
"rdx" // Clobbered registers
);
return (uint64_t)cycles_high << 32 | cycles_low;
}
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128i reg;
explicit FP16Vec8(const void* ptr)
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
// normal load
explicit FP16Vec16(const void* ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
// non-temporal load
explicit FP16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm256_mask_storeu_epi16(ptr, mask, reg);
}
};
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128i reg;
explicit BF16Vec8(const void* ptr)
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
// normal load
explicit BF16Vec16(const void* ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
// non-temporal load
explicit BF16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm256_mask_storeu_epi16(ptr, mask, reg);
}
};
#ifdef __AVX512F__
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m512i reg;
explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(__m512i data) : reg(data) {}
explicit BF16Vec32(BF16Vec8& vec8_data)
: reg((__m512i)_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
(__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1),
(__m128i)vec8_data.reg, 2),
(__m128i)vec8_data.reg, 3)) {}
void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m256i reg_low;
__m256i reg_high;
explicit BF16Vec32(const void* ptr)
: reg_low(_mm256_loadu_si256((__m256i const*)ptr)),
reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {}
explicit BF16Vec32(__m256i low, __m256i high)
: reg_low(low), reg_high(high) {}
explicit BF16Vec32(BF16Vec8& vec8_data)
: reg_low((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)),
reg_high((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)) {}
void save(void* ptr) const {
_mm256_storeu_si256((__m256i*)ptr, reg_low);
_mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
}
};
#endif
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__m128 reg;
float values[VEC_ELEM_NUM];
};
__m128 reg;
explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {}
explicit FP32Vec4(__m128 data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
__m256 reg;
float values[VEC_ELEM_NUM];
};
__m256 reg;
explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {}
explicit FP32Vec8(__m256 data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}
explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {}
explicit FP32Vec8(const BF16Vec8& v)
: reg(_mm256_castsi256_ps(
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
expf(ar.values[5]), expf(ar.values[4]),
expf(ar.values[3]), expf(ar.values[2]),
expf(ar.values[1]), expf(ar.values[0])));
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
tanhf(ar.values[5]), tanhf(ar.values[4]),
tanhf(ar.values[3]), tanhf(ar.values[2]),
tanhf(ar.values[1]), tanhf(ar.values[0])));
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
erf(ar.values[5]), erf(ar.values[4]),
erf(ar.values[3]), erf(ar.values[2]),
erf(ar.values[1]), erf(ar.values[0])));
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(_mm256_add_ps(reg, b.reg));
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(_mm256_div_ps(reg, b.reg));
}
void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); }
};
#ifdef __AVX512F__
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512i reg;
int32_t values[VEC_ELEM_NUM];
};
__m512i reg;
explicit INT32Vec16(const void* data_ptr)
: reg(_mm512_loadu_epi32(data_ptr)) {}
void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }
void save(int32_t* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm512_mask_storeu_epi32(ptr, mask, reg);
}
};
#endif
#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512 reg;
float values[VEC_ELEM_NUM];
};
__m512 reg;
explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
// normal load
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
// non-temporal load
explicit FP32Vec16(bool, void* ptr)
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
// de-pack 4 bit values
explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
int64_t value_0 = value & mask_0;
int64_t value_1 = value & mask_1;
__m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
__m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
vec_0 = _mm_cvtepu8_epi16(vec_0);
vec_1 = _mm_cvtepu8_epi16(vec_1);
vec_1 = _mm_slli_epi16(vec_1, 4);
__m128i vec = _mm_or_si128(vec_0, vec_1);
__m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
}
explicit FP32Vec16(const FP32Vec4& data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
(__m128i)data.reg, 1),
(__m128i)data.reg, 2),
(__m128i)data.reg, 3)) {}
explicit FP32Vec16(const FP32Vec8& data)
: reg((__m512)_mm512_inserti32x8(
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
explicit FP32Vec16(const BF16Vec16& v)
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v)
: reg(_mm512_cvt_roundepi32_ps(
v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(_mm512_add_ps(reg, b.reg));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg)));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(_mm512_max_ps(reg, b.reg));
}
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(_mm512_min_ps(reg, b.reg));
}
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
}
FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); }
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
float reduce_max() const { return _mm512_reduce_max_ps(reg); }
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
void save(float* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm512_mask_storeu_ps(ptr, mask, reg);
}
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m256 reg;
float values[8];
};
__m256 reg_low;
__m256 reg_high;
explicit FP32Vec16(float v)
: reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}
explicit FP32Vec16()
: reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {}
explicit FP32Vec16(const float* ptr)
: reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
explicit FP32Vec16(const FP32Vec4& data)
: reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
reg_high((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}
explicit FP32Vec16(const FP32Vec8& data)
: reg_low(data.reg), reg_high(data.reg) {}
explicit FP32Vec16(const FP16Vec16& v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);
reg_low = _mm256_cvtph_ps(low);
reg_high = _mm256_cvtph_ps(high);
}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const BF16Vec16& v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);
__m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
__m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
__m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
__m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
reg_low = _mm256_castsi256_ps(v_low_shifted);
reg_high = _mm256_castsi256_ps(v_high_shifted);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
_mm256_mul_ps(reg_high, b.reg_high));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
_mm256_add_ps(reg_high, b.reg_high));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
_mm256_sub_ps(reg_high, b.reg_high));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
_mm256_div_ps(reg_high, b.reg_high));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(_mm256_max_ps(reg_low, b.reg_low),
_mm256_max_ps(reg_high, b.reg_high));
}
float reduce_max() const {
__m256 v = _mm256_max_ps(reg_low, reg_high);
// Permute to compare elements within 128-bit lanes
__m256 v_shuffled = _mm256_permute_ps(
v, 0b00001011); // Swap halves within each 128-bit lane
__m256 v_max = _mm256_max_ps(v, v_shuffled);
v_shuffled = _mm256_permute_ps(
v_max, 0b00000001); // Shuffle elements within each 128-bit lane
v_max = _mm256_max_ps(v_max, v_shuffled);
// Permute to compare elements between 128-bit lanes
v_shuffled =
_mm256_permute2f128_ps(v_max, v_max, 0b00000001); // Swap 128-bit lanes
v_max = _mm256_max_ps(v_max, v_shuffled);
// At this point, the maximum value is present in all elements of v_max.
// Extract the first element for the scalar result.
return _mm256_cvtss_f32(v_max); // Extract the lowest 32-bit float
}
float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high);
return low.reduce_sum() + high.reduce_sum();
}
template <int group_size>
float reduce_sub_sum(int idx) {
float sum = 0.0;
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
uint32_t mask = base_mask << (idx * group_size);
AliasReg ar;
auto func = [&sum, &mask, &ar](int i) {
int flag = mask & 0x1;
mask = mask >> 1;
if (flag != 0) sum += ar.values[i];
};
ar.reg = reg_low;
unroll_loop<int, 8>(func);
ar.reg = reg_high;
unroll_loop<int, 8>(func);
return sum;
}
void save(float* ptr) const {
_mm256_storeu_ps(ptr, reg_low);
_mm256_storeu_ps(ptr + 8, reg_high);
}
};
#endif
#ifdef __AVX512F__
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m128i reg;
int8_t values[VEC_ELEM_NUM];
};
__m128i reg;
explicit INT8Vec16(const FP32Vec16& vec)
: reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}
void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }
void save(int8_t* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm_mask_storeu_epi8(ptr, mask, reg);
}
};
struct INT8Vec64 : public Vec<INT8Vec64> {
constexpr static int VEC_ELEM_NUM = 64;
union AliasReg {
__m512i reg;
int8_t values[VEC_ELEM_NUM];
};
__m512i reg;
// normal load
explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
// non-temporal load
explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
void save(int8_t* ptr, const int elem_num) const {
constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
__mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
_mm512_mask_storeu_epi8(ptr, mask, reg);
}
// non-temporal save
void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
};
#endif
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc = acc + a * b;
}
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*reinterpret_cast<unsigned short*>(ptr) =
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}
inline FP16Vec8::FP16Vec8(const FP32Vec8& v)
: reg(_mm256_cvtps_ph(v.reg,
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#ifdef __AVX512F__
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
: reg(_mm512_cvtps_ph(v.reg,
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#else
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
: reg(_mm256_insertf128_si256(
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
#endif
#ifdef __AVX512BF16__
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
*reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#else
namespace {
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
__m256i ai = _mm256_castps_si256(a);
ai = _mm256_srli_epi32(ai, 16);
ai = _mm256_packus_epi32(ai, ai);
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
return _mm256_extracti128_si256(ai, 0);
}
} // namespace
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif // __AVX512F__
#endif // __AVX512BF16__
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
#ifdef __AVX512F__
inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
_mm256_stream_si256((__m256i*)ptr, vec.reg);
}
inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
_mm512_stream_si512((__m512i*)ptr, vec.reg);
}
inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
_mm256_stream_si256((__m256i*)ptr, vec.reg);
}
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
_mm512_stream_ps((float*)ptr, vec.reg);
}
static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
void* ptr) {
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
vec_1 = _mm512_slli_epi32(vec_1, 16);
vec_0 = _mm512_or_si512(vec_0, vec_1);
_mm512_storeu_epi32(ptr, vec_0);
}
static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
void* ptr) {
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
vec_1 = _mm512_slli_epi32(vec_1, 16);
vec_0 = _mm512_or_si512(vec_0, vec_1);
_mm512_storeu_epi32(ptr, vec_0);
}
#endif
inline void mem_barrier() { _mm_mfence(); }
}; // namespace vec_op
#endif

402
csrc/cpu/cpu_wna16.cpp Normal file
View File

@@ -0,0 +1,402 @@
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
#endif
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
#define VLLM_DISPATCH_CASE_16B_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__))
template <typename T>
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
namespace {
using cpu_utils::ISA;
using cpu_utils::VecTypeTrait;
template <typename scalar_t, ISA isa, bool has_zp, bool use_desc_act>
class Dequantizer4b {
public:
constexpr static int32_t pack_num = 32 / 4;
using scalar_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
public:
static void dequant(int32_t* __restrict__ q_weight,
scalar_t* __restrict__ weight,
scalar_t* __restrict__ scales,
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
const int64_t scales_stride, const int64_t zeros_stride,
const int32_t k_size, const int32_t group_size) {
vec_op::FP32Vec16 lut;
if constexpr (has_zp) {
// AWQ
alignas(64) static const float LUT[16] = {
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
lut = vec_op::FP32Vec16(LUT);
} else {
// GPTQ
alignas(64) static const float LUT[16] = {
-8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
lut = vec_op::FP32Vec16(LUT);
}
// per 64-bits elem contains 16 output channels
int64_t* __restrict__ curr_q_weight = reinterpret_cast<int64_t*>(q_weight);
int64_t* __restrict__ curr_zeros = reinterpret_cast<int64_t*>(zeros);
scalar_t* __restrict__ curr_weight = weight;
scalar_t* __restrict__ curr_scale = scales;
vec_op::FP32Vec16 scale_0;
vec_op::FP32Vec16 scale_1;
vec_op::FP32Vec16 zero_0;
vec_op::FP32Vec16 zero_1;
int32_t group_counter = 0;
for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) {
int64_t qwb_0 = *curr_q_weight;
int64_t qwb_1 = *(curr_q_weight + 1);
vec_op::FP32Vec16 wb_0(qwb_0, lut);
vec_op::FP32Vec16 wb_1(qwb_1, lut);
if constexpr (!use_desc_act) {
if (group_counter == 0) {
scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale));
scale_1 = vec_op::FP32Vec16(scale_0);
curr_scale += scales_stride;
if constexpr (has_zp) {
zero_0 = vec_op::FP32Vec16(*curr_zeros, lut);
zero_1 = vec_op::FP32Vec16(zero_0);
curr_zeros += zeros_stride / 2;
}
}
} else {
int32_t g_idx_0 = g_idx[k_idx];
int32_t g_idx_1 = g_idx[k_idx + 1];
scale_0 = vec_op::FP32Vec16(
scalar_vec_t(curr_scale + g_idx_0 * scales_stride));
scale_1 = vec_op::FP32Vec16(
scalar_vec_t(curr_scale + g_idx_1 * scales_stride));
if constexpr (has_zp) {
zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2),
lut);
zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2),
lut);
}
}
if constexpr (has_zp) {
wb_0 = wb_0 - zero_0;
wb_1 = wb_1 - zero_1;
}
wb_0 = wb_0 * scale_0;
wb_1 = wb_1 * scale_1;
scalar_vec_t output_vec_0(wb_0);
scalar_vec_t output_vec_1(wb_1);
// AMX needs to interlave K elements to pack as 32 bits
if constexpr (isa == ISA::AMX) {
vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight);
} else {
output_vec_0.save(curr_weight);
output_vec_1.save(curr_weight + 16);
}
// update
curr_q_weight += 2;
curr_weight += 32;
if constexpr (!use_desc_act) {
group_counter += 2;
if (group_counter == group_size) {
group_counter = 0;
}
}
}
}
};
}; // namespace
template <typename scalar_t, typename dequantizer_t, typename gemm_t>
void cpu_gemm_wna16_impl(
scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
const int32_t k_size, const int64_t input_stride,
const int64_t output_stride, const int64_t scales_group_stride,
const int64_t zeros_group_stride, const int32_t group_num,
const int32_t group_size, const int64_t pack_factor) {
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
constexpr int32_t n_block_size = 16;
static_assert(gemm_n_tile_size % n_block_size == 0);
const int32_t thread_num = omp_get_max_threads();
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const int32_t n_partition_size = [&]() {
const int64_t cache_size = cpu_utils::get_l2_size();
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
int64_t ps_thread_limit = n_size / thread_num;
ps_cache_limit =
std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
(int64_t)gemm_n_tile_size);
ps_thread_limit =
std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
(int64_t)gemm_n_tile_size);
return std::min(ps_cache_limit, ps_thread_limit);
}();
const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;
// get buffer size
const int64_t b_buffer_size =
(((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
const int64_t c_buffer_size =
(((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
const int64_t b_buffer_offset = 0;
const int64_t c_buffer_offset = b_buffer_size;
const int64_t buffer_size = b_buffer_size + c_buffer_size;
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
thread_num);
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
scalar_t* __restrict__ b_buffer = nullptr;
float* __restrict__ c_buffer = nullptr;
{
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
->get_data<uint8_t>() +
thread_id * buffer_size;
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
}
const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
const int64_t b_buffer_block_stride = n_block_size * k_size;
const int32_t zeros_block_stride = n_block_size / pack_factor;
gemm_t gemm;
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
const int32_t n_start_idx = task_id * n_partition_size;
const int32_t n_block_start_idx = n_start_idx / n_block_size;
const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
const int32_t n_block_num = n_num / n_block_size;
// std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
// thread_id, task_id, n_start_idx, n_num);
// dequant weight
{
int32_t* __restrict__ curr_q_weight =
q_weight + n_block_start_idx * q_weight_block_stride;
scalar_t* __restrict__ curr_b_buffer = b_buffer;
scalar_t* __restrict__ curr_scales = scales + n_start_idx;
int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
curr_zeros, g_idx, scales_group_stride,
zeros_group_stride, k_size, group_size);
// if (block_idx == 0 && n_start_idx == 0) {
// print_logits("depacked weight", curr_b_buffer, k_size,
// n_block_size, n_block_size);
// }
// update
curr_q_weight += q_weight_block_stride;
curr_b_buffer += b_buffer_block_stride;
curr_scales += n_block_size;
curr_zeros += zeros_block_stride;
}
}
// compute loop
{
const int32_t n_tile_num = n_num / gemm_n_tile_size;
scalar_t* __restrict__ curr_input = input;
scalar_t* __restrict__ init_bias = bias;
if (bias != nullptr) {
init_bias += n_start_idx;
}
scalar_t* __restrict__ init_output = output + n_start_idx;
for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
const int32_t curr_m_size =
std::min(gemm_m_tile_size, m_size - m_idx);
scalar_t* __restrict__ curr_b_buffer = b_buffer;
scalar_t* __restrict__ curr_bias = init_bias;
scalar_t* __restrict__ curr_output = init_output;
for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
input_stride, b_buffer_block_stride, gemm_n_tile_size,
false);
if (bias != nullptr) {
cpu_micro_gemm::bias_epilogue<gemm_n_tile_size>(
c_buffer, curr_output, curr_bias, curr_m_size,
gemm_n_tile_size, output_stride);
curr_bias += gemm_n_tile_size;
} else {
cpu_micro_gemm::default_epilogue<gemm_n_tile_size>(
c_buffer, curr_output, curr_m_size, gemm_n_tile_size,
output_stride);
}
curr_b_buffer +=
b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
curr_output += gemm_n_tile_size;
}
curr_input += gemm_m_tile_size * input_stride;
init_output += gemm_m_tile_size * output_stride;
}
}
}
}
}
void cpu_gemm_wna16(
const torch::Tensor& input, // [M, K]
const torch::Tensor&
q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
torch::Tensor& output, // [M, N]
const torch::Tensor& scales, // [group_num, N]
const std::optional<torch::Tensor>&
zeros, // [group_num, N / pack_factor], packed as int32
const std::optional<torch::Tensor>& g_idx, // [K]
const std::optional<torch::Tensor>& bias, // [N]
const int64_t pack_factor, const std::string& isa_hint) {
using cpu_utils::ISA;
TORCH_CHECK_EQ(pack_factor, 8); // only supports 4bits
const int32_t a_m_size = input.size(0);
const int32_t a_k_size = input.size(1);
const int64_t a_m_stride = input.stride(0);
const int32_t b_n_size = q_weight.size(0) * 16;
TORCH_CHECK_EQ(a_k_size % 32, 0);
TORCH_CHECK_EQ(b_n_size % 32, 0);
const int32_t group_num = scales.size(0);
const int32_t group_size = a_k_size / group_num;
TORCH_CHECK_EQ(group_size % 2, 0);
const int64_t scales_group_stride = scales.stride(0);
const int64_t output_m_stride = output.stride(0);
bool has_zp = zeros.has_value();
bool use_desc_act = g_idx.has_value();
TORCH_CHECK(!(has_zp && use_desc_act));
ISA isa = [&]() {
if (isa_hint == "amx") {
return ISA::AMX;
} else if (isa_hint == "vec") {
return ISA::VEC;
} else {
TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
}
}();
int32_t* zeros_ptr = has_zp ? zeros->data_ptr<int32_t>() : nullptr;
const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr<int32_t>() : nullptr;
VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
if (isa == ISA::AMX) {
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::AMX, scalar_t>;
if (has_zp) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, true, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
if (use_desc_act) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, true>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
} else {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
} else if (isa == ISA::VEC) {
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::VEC, scalar_t>;
if (has_zp) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, true, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
if (use_desc_act) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, true>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
} else {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
}
});
}

569
csrc/cpu/dnnl_helper.cpp Normal file
View File

@@ -0,0 +1,569 @@
#include <list>
#include <optional>
#include "common/memory_desc.hpp"
#include "common/memory.hpp"
#include "dnnl_helper.h"
#include "scratchpad_manager.h"
static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
return engine;
}
static dnnl::stream& default_stream() {
static dnnl::stream stream(default_engine());
return stream;
}
void release_dnnl_matmul_handler(int64_t handler) {
DNNLMatMulPrimitiveHandler* ptr =
reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
delete ptr;
}
template <typename KT, typename VT>
class DNNLPrimitiveCache {
public:
using cache_value_t = std::pair<KT, VT>;
using result_value_t = VT;
using container_t = std::list<cache_value_t>;
using value_iterator_t = typename container_t::iterator;
using map_t = std::unordered_map<KT, value_iterator_t>;
using creator_t = VT (*)();
public:
DNNLPrimitiveCache(size_t capacity)
: capacity_(capacity),
values_(),
key_to_value_(std::min(256lu, capacity)) {
assert(capacity > 0);
}
template <typename F>
result_value_t get_or_create(const KT& key, F&& creator) {
std::optional<value_iterator_t> value = get_value(key);
if (value.has_value()) {
return value.value()->second;
} else {
return add_value({key, creator()})->second;
}
}
size_t size() const { return values_.size(); }
private:
void dump_data() {
std::stringstream ss;
ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
<< "\n";
ss << "container: [";
for (auto&& iter : values_) {
ss << "(" << iter.first << ", " << std::hex
<< reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
}
ss << "]\n";
ss << "map: [";
for (auto&& iter : key_to_value_) {
ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
<< reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
<< "), ";
}
ss << "]\n";
std::printf("%s\n", ss.str().c_str());
}
value_iterator_t add_value(cache_value_t&& new_value) {
if (size() == capacity_) {
cache_value_t& last_item = values_.back();
key_to_value_.erase(last_item.first);
values_.pop_back();
}
auto& added_value_ = values_.emplace_front(std::move(new_value));
key_to_value_.emplace(added_value_.first, values_.begin());
return values_.begin();
}
std::optional<value_iterator_t> get_value(const KT& key) {
if (key_to_value_.size() > 0 && key == values_.begin()->first) {
return values_.begin();
}
auto value_map_iterator = key_to_value_.find(key);
if (value_map_iterator != key_to_value_.end()) {
values_.splice(values_.begin(), values_, value_map_iterator->second);
return value_map_iterator->second;
} else {
return {};
}
}
private:
const size_t capacity_;
container_t values_;
map_t key_to_value_;
};
DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
const Args& args, dnnl::memory::data_type b_type)
: b_n_size_(args.b_n_size),
b_n_stride_(args.b_n_stride),
b_k_size_(args.b_k_size),
b_k_stride_(args.b_k_stride),
b_type_(b_type),
c_type_(args.c_type),
runtime_memory_ptrs_(8),
primitive_cache_size_(args.primitive_cache_size) {
assert(primitive_cache_size_ > 0);
}
void DNNLMatMulPrimitiveHandler::prepack_weight(
void* original_b_ptr, dnnl::memory::desc original_b_md,
dnnl::memory::desc b_target_mem_desc) {
dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
dnnl::memory packed_weight(b_target_mem_desc, default_engine());
{
dnnl::reorder(original_weight, packed_weight)
.execute(default_stream(), original_weight, packed_weight);
default_stream().wait();
}
memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
b_target_mem_desc_ = b_target_mem_desc;
}
void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
size_t index, dnnl_memory* memory_ptr) {
dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md());
runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
}
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
return runtime_memory_ptrs_[index];
}
namespace std {
template <>
struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
size_t operator()(
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
hash<int>()(static_cast<int>(val.a_qs)) ^
hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
hash<int>()(static_cast<int>(val.c_type));
}
};
template <>
struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
size_t operator()(
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
hash<int>()(static_cast<int>(val.bias_type));
}
};
template <>
struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
size_t operator()(
const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
hash<int>()(static_cast<int>(val.b_type));
}
};
template <>
struct hash<MatMulPrimitiveHandler::MSizeCacheKey> {
size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
return hash<dnnl_dim_t>()(val.a_m_size) ^
hash<dnnl_dim_t>()(val.a_m_stride) ^ hash<bool>()(val.use_bias) ^
hash<int>()(static_cast<int>(val.bias_type));
}
};
} // namespace std
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
l.c_type == r.c_type;
}
bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
l.bias_type == r.bias_type;
}
bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
l.b_type == r.b_type;
}
bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
const MatMulPrimitiveHandler::MSizeCacheKey& r) {
return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride &&
l.use_bias == r.use_bias && l.bias_type == r.bias_type;
}
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
get_w8a8_class_primitive_cache(
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
int64_t cache_size) {
static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
assert(cache_size > 0);
return cache.get_or_create(key, [&]() {
return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
});
}
W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
dnnl::memory::data_type::s8),
use_azp_(args.use_a_zero_point),
a_qs_(args.a_quantization_strategy),
b_qs_(args.b_quantization_strategy),
m_size_cache_(nullptr) {
assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
assert(!use_azp_);
};
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc(
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
.weights_desc());
init_runtime_memory_cache(args);
}
void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
a_storage->set_data_handle((void*)args.a_ptr);
a_mem_desc->dims[0] = args.a_m_size;
c_storage->set_data_handle((void*)args.c_ptr);
c_mem_desc->dims[0] = args.a_m_size;
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
}
if (use_azp_) {
auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
get_runtime_memory_ptr(3);
a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
}
if (args.use_bias) {
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
bias_storage->set_data_handle((void*)args.bias_ptr);
}
dnnl::matmul matmul = get_matmul_cache(args);
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
}
dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
const MSizeCacheKey& key) {
if (m_size_cache_.get() == nullptr) {
ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
.b_k_size = b_k_size_,
.a_qs = a_qs_,
.b_qs = b_qs_,
.use_azp = use_azp_,
.c_type = c_type_};
m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
}
void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
dnnl::memory::data_type::s8,
dnnl::memory::format_tag::ab},
default_engine(), nullptr);
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
memory_cache_[DNNL_ARG_DST] =
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
default_engine(), nullptr);
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
// For PER_TOKEN, scales will be applied in outside epilogue
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
{{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
set_runtime_memory_ptr(
2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
if (use_azp_) {
memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
{{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
set_runtime_memory_ptr(
3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
}
}
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
(void*)args.b_scales_ptr);
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), (void*)args.b_scales_ptr);
}
memory_cache_[DNNL_ARG_BIAS] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
const MSizeCacheKey& key, bool first_time) {
dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
dnnl::memory::data_type::s8,
dnnl::memory::format_tag::ab);
dnnl::memory::desc b_md;
if (first_time) {
b_md =
dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
dnnl::memory::format_tag::any);
} else {
b_md = b_target_mem_desc_;
}
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// For PER_TOKEN, scales will be applied in outside epilogue
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
attr.set_scales_mask(DNNL_ARG_SRC, 0);
if (use_azp_) {
attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
}
}
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
}
if (key.use_bias) {
// For PER_TOKEN, bias will be applied in epilogue
assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
c_md, attr);
} else {
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
}
}
MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
m_size_cache_(nullptr) {
assert(b_type_ == dnnl::memory::data_type::f32 ||
b_type_ == dnnl::memory::data_type::bf16 ||
b_type_ == dnnl::memory::data_type::f16);
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc(
MSizeCacheKey{
#ifdef VLLM_USE_ACL
// Arm Compute Library (ACL) backend for oneDNN does
// not support runtime
// dimensions, so we set M to a default value
.a_m_size = 128,
.a_m_stride = b_k_size_,
#else
.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
#endif
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
.weights_desc());
init_runtime_memory_cache(args);
}
static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
get_matul_class_primitive_cache(
const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
int64_t cache_size) {
static MatMulPrimitiveHandler::ClassMatmulCache cache(128);
assert(cache_size > 0);
return cache.get_or_create(key, [&]() {
return std::make_shared<MatMulPrimitiveHandler::MSizeCache>(cache_size);
});
}
void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
a_storage->set_data_handle((void*)args.a_ptr);
a_mem_desc->dims[0] = args.a_m_size;
a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride;
c_storage->set_data_handle((void*)args.c_ptr);
c_mem_desc->dims[0] = args.a_m_size;
#ifndef VLLM_USE_ACL
// We do not support in ACL backend of oneDNN, we handle bias by:
// 1. copying it into the result tensor
// 2. attaching a fused-sum post-op to the matmul primitive
if (args.use_bias) {
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
bias_storage->set_data_handle((void*)args.bias_ptr);
}
#endif
dnnl::matmul matmul = get_matmul_cache(args);
// With ACL backend of oneDNN, the required memory format might change when the
// source tensor dims change. This does not really happen in practice, so isn't
// a performance hit, but we need to support it because the API allows for it.
#ifdef VLLM_USE_ACL
auto new_expected_wei_desc =
dnnl::matmul::primitive_desc(
const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
.weights_desc();
if (new_expected_wei_desc != b_target_mem_desc_) {
prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
b_target_mem_desc_, new_expected_wei_desc);
}
#endif
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
}
dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
const MSizeCacheKey& key) {
if (m_size_cache_.get() == nullptr) {
ClassMatmulCacheKey class_key = {
.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
m_size_cache_ =
get_matul_class_primitive_cache(class_key, primitive_cache_size_);
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
}
dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
const MSizeCacheKey& key, bool first_time) {
dnnl::memory::desc a_md;
dnnl::memory::desc b_md;
if (first_time) {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
dnnl::memory::format_tag::ab);
b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
dnnl::memory::format_tag::any);
} else {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
{key.a_m_stride, 1});
#ifdef VLLM_USE_ACL
// ACL's backend of oneDNN always expects the weight format to be "any"
b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
dnnl::memory::format_tag::any);
#else
b_md = b_target_mem_desc_;
#endif
}
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (key.use_bias) {
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
// Since ACL's matmuls don't support passing a bias_md, we apply the bias
// through a fused-sum post-op
#ifdef VLLM_USE_ACL
dnnl::post_ops post_ops;
post_ops.append_sum();
attr.set_post_ops(post_ops);
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
#else
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
c_md, attr);
#endif
} else {
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
}
}
void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
{{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
memory_cache_[DNNL_ARG_DST] =
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
default_engine(), nullptr);
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
// ACL matmuls don't support bias_md, so we don't need these
#ifndef VLLM_USE_ACL
memory_cache_[DNNL_ARG_BIAS] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
#endif
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}
bool is_onednn_acl_supported() {
#ifdef VLLM_USE_ACL
return true;
#else
return false;
#endif
}

220
csrc/cpu/dnnl_helper.h Normal file
View File

@@ -0,0 +1,220 @@
#ifndef DNNL_HELPER_H
#define DNNL_HELPER_H
#include <optional>
#include <cassert>
#include "oneapi/dnnl/dnnl.hpp"
namespace c10 {
struct BFloat16;
struct Half;
} // namespace c10
namespace dnnl {
namespace impl {
struct memory_storage_t;
struct matmul_pd_t;
struct matmul_desc_t;
} // namespace impl
} // namespace dnnl
struct dnnl_memory_desc;
template <typename KT, typename VT>
class DNNLPrimitiveCache;
template <typename T>
struct DNNLType {
static constexpr dnnl::memory::data_type type =
dnnl::memory::data_type::undef;
};
template <>
struct DNNLType<int8_t> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};
template <>
struct DNNLType<int32_t> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};
template <>
struct DNNLType<float> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};
template <>
struct DNNLType<c10::BFloat16> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};
template <>
struct DNNLType<c10::Half> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};
template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type;
}
class DNNLMatMulPrimitiveHandler {
public:
virtual ~DNNLMatMulPrimitiveHandler() = default;
protected:
struct Args {
dnnl_dim_t b_n_size;
dnnl_dim_t b_n_stride;
dnnl_dim_t b_k_size;
dnnl_dim_t b_k_stride;
void* b_ptr;
dnnl::memory::data_type c_type;
size_t primitive_cache_size;
};
protected:
DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md,
dnnl::memory::desc b_target_mem_desc);
void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
get_runtime_memory_ptr(size_t index);
protected:
const dnnl_dim_t b_n_size_;
const dnnl_dim_t b_n_stride_;
const dnnl_dim_t b_k_size_;
const dnnl_dim_t b_k_stride_;
dnnl::memory::data_type b_type_;
dnnl::memory::data_type c_type_;
std::unordered_map<int, dnnl::memory> memory_cache_;
std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
runtime_memory_ptrs_;
dnnl::memory::desc b_target_mem_desc_;
int64_t primitive_cache_size_;
};
class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
public:
enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
struct Args : public DNNLMatMulPrimitiveHandler::Args {
bool use_a_zero_point;
QuantizationStrategy a_quantization_strategy;
QuantizationStrategy b_quantization_strategy;
float* b_scales_ptr;
};
struct ClassMatmulCacheKey {
dnnl_dim_t b_n_size;
dnnl_dim_t b_k_size;
QuantizationStrategy a_qs;
QuantizationStrategy b_qs;
bool use_azp;
dnnl::memory::data_type c_type;
friend bool operator==(const ClassMatmulCacheKey& l,
const ClassMatmulCacheKey& r);
};
struct MSizeCacheKey {
dnnl_dim_t a_m_size;
bool use_bias;
dnnl::memory::data_type bias_type;
friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
};
using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
using ClassMatmulCache =
DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
struct ExecArgs : public MSizeCacheKey {
const int8_t* a_ptr;
const float* a_scales_ptr;
const int32_t* a_zero_points_ptr;
const void* bias_ptr;
void* c_ptr;
};
public:
W8A8MatMulPrimitiveHandler(const Args& args);
QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
bool get_input_use_zero_point() const { return use_azp_; }
void execute(ExecArgs& args);
private:
dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
bool first_time);
void init_runtime_memory_cache(const Args& args);
dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
private:
const bool use_azp_;
const QuantizationStrategy a_qs_;
const QuantizationStrategy b_qs_;
std::shared_ptr<MSizeCache> m_size_cache_;
};
class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
public:
struct Args : public DNNLMatMulPrimitiveHandler::Args {
dnnl::memory::data_type ab_type;
};
struct ClassMatmulCacheKey {
dnnl_dim_t b_n_size;
dnnl_dim_t b_k_size;
dnnl::memory::data_type b_type;
friend bool operator==(const ClassMatmulCacheKey& l,
const ClassMatmulCacheKey& r);
};
struct MSizeCacheKey {
dnnl_dim_t a_m_size;
dnnl_dim_t a_m_stride;
bool use_bias;
dnnl::memory::data_type bias_type;
friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
};
using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
using ClassMatmulCache =
DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
struct ExecArgs : public MSizeCacheKey {
const void* a_ptr;
const void* bias_ptr;
void* c_ptr;
};
public:
MatMulPrimitiveHandler(const Args& args);
void execute(ExecArgs& args);
private:
dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
bool first_time);
void init_runtime_memory_cache(const Args& args);
dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
private:
std::shared_ptr<MSizeCache> m_size_cache_;
};
#endif

570
csrc/cpu/dnnl_kernels.cpp Normal file
View File

@@ -0,0 +1,570 @@
#include "cpu_types.hpp"
#include "dnnl_helper.h"
namespace {
template <typename scalar_t>
struct KernelVecType {
using load_vec_type = void;
using cvt_vec_type = void;
};
template <>
struct KernelVecType<float> {
using load_vec_type = vec_op::FP32Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct KernelVecType<c10::BFloat16> {
using load_vec_type = vec_op::BF16Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
#endif
template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
// Power architecture-specific vector type
using load_vec_type = vec_op::FP32Vec16;
#else
// Fallback for other architectures
using load_vec_type = vec_op::FP16Vec16;
#endif
using cvt_vec_type = vec_op::FP32Vec16;
};
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp,
const int64_t num_tokens,
const int64_t input_stride,
const int64_t hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t inv_scale(1.0 / *scale);
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
cvt_vec_t zp_vec;
if constexpr (AZP) {
zp_vec = cvt_vec_t(static_cast<float>(*azp));
}
#pragma omp parallel for
for (int64_t i = 0; i < num_tokens; ++i) {
int64_t j = 0;
const scalar_t* input_ptr = input + i * input_stride;
int8_t* output_ptr = output + i * hidden_size;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input_ptr + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output_ptr + j);
}
load_vec_t elems(input_ptr + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output_ptr + j, hidden_size - j);
}
}
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp,
const int64_t num_tokens,
const int64_t input_stride,
const int64_t hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
#pragma omp parallel for
for (int64_t i = 0; i < num_tokens; ++i) {
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
cvt_vec_t min_value(std::numeric_limits<float>::max());
{
int64_t j = 0;
const scalar_t* input_ptr = input + i * input_stride;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input_ptr + j);
cvt_vec_t elems_fp32(elems);
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
}
load_vec_t elems(input_ptr + j);
cvt_vec_t elems_fp32(elems);
if (j + vec_elem_num == hidden_size) {
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
} else {
if constexpr (AZP) {
max_value = max_value.max(elems_fp32, hidden_size - j);
min_value = min_value.min(elems_fp32, hidden_size - j);
} else {
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
}
}
}
float scale_val;
float azp_val = 0.0f;
if constexpr (AZP) {
float max_scalar = max_value.reduce_max();
float min_scalar = min_value.reduce_min();
scale_val = (max_scalar - min_scalar) / 255.0f;
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
azp[i] = azp_val;
scale[i] = scale_val;
} else {
scale_val = max_value.reduce_max() / 127.0f;
scale[i] = scale_val;
}
const cvt_vec_t inv_scale(1.0 / scale_val);
const cvt_vec_t azp_vec(azp_val);
{
int64_t j = 0;
const scalar_t* input_ptr = input + i * input_stride;
int8_t* output_ptr = output + i * hidden_size;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input_ptr + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output_ptr + j);
}
load_vec_t elems(input_ptr + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output_ptr + j, hidden_size - j);
}
}
}
template <bool AZP, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const int32_t* azp,
const float* azp_adj, const scalar_t* bias,
const int64_t num_tokens,
const int64_t hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
const int64_t thread_num = omp_get_max_threads();
if (num_tokens > thread_num) {
#pragma omp parallel for
for (int64_t i = 0; i < num_tokens; ++i) {
const float* input_ptr = input + i * hidden_size;
scalar_t* output_ptr = output + i * hidden_size;
int64_t j = 0;
cvt_vec_t token_scale_vec(a_scale[i]);
cvt_vec_t token_zp_scale_vec;
if constexpr (AZP) {
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
}
for (; j < hidden_size - vec_elem_num; ++j) {
cvt_vec_t elems_fp32(input_ptr + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
cvt_vec_t azp_adj_fp32(azp_adj + j);
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output_ptr + j);
}
cvt_vec_t elems_fp32(input_ptr + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
cvt_vec_t azp_adj_fp32(azp_adj + j);
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output_ptr + j, hidden_size - j);
}
} else {
const int64_t vec_iteration =
(hidden_size + vec_elem_num - 1) / vec_elem_num;
const int64_t vec_iteration_per_thread =
(vec_iteration + thread_num - 1) / thread_num;
const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
#pragma omp parallel for schedule(static, 1)
for (int64_t i = 0; i < thread_num; ++i) {
const int64_t start = elem_num_per_thread * i;
const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
for (int64_t j = 0; j < num_tokens; ++j) {
cvt_vec_t token_scale_vec(a_scale[j]);
cvt_vec_t token_zp_scale_vec;
if constexpr (AZP) {
float zp_scale_val = a_scale[j] * static_cast<float>(azp[j]);
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
}
int64_t k = start;
const float* input_ptr = input + j * hidden_size;
scalar_t* output_ptr = output + j * hidden_size;
for (; k < end - vec_elem_num; k += vec_elem_num) {
cvt_vec_t elems_fp32(input_ptr + k);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
cvt_vec_t azp_adj_fp32(azp_adj + k);
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + k);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output_ptr + k);
}
if (k < end) {
cvt_vec_t elems_fp32(input_ptr + k);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
cvt_vec_t azp_adj_fp32(azp_adj + k);
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + k);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output_ptr + k, end - k);
}
}
}
}
}
} // namespace
int64_t create_onednn_scaled_mm_handler(
const torch::Tensor& b, // [IC, OC], column-major
const torch::Tensor& b_scales, // [1] or [OC]
at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
int64_t primitive_cache_size) {
TORCH_CHECK(b.dim() == 2);
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(b_scales.is_contiguous());
W8A8MatMulPrimitiveHandler::Args args;
args.primitive_cache_size = primitive_cache_size;
if (b_scales.numel() == 1) {
args.b_quantization_strategy =
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
} else {
TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
args.b_quantization_strategy =
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
}
args.b_scales_ptr = b_scales.data_ptr<float>();
args.b_k_size = b.size(0);
args.b_k_stride = b.stride(0);
args.b_n_size = b.size(1);
args.b_n_stride = b.stride(1);
args.b_ptr = b.data_ptr<int8_t>();
if (dynamic_act_quant) {
// dynamic per-token, bias, A scales and A zps will be applied in outside.
args.a_quantization_strategy =
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
args.use_a_zero_point = false;
} else {
// static per-tensor
args.a_quantization_strategy =
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
args.use_a_zero_point = use_azp;
}
VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
[&] {
if (dynamic_act_quant) {
args.c_type = get_dnnl_type<float>();
} else {
args.c_type = get_dnnl_type<scalar_t>();
}
});
return reinterpret_cast<int64_t>(new W8A8MatMulPrimitiveHandler(args));
}
void onednn_scaled_mm(
torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const torch::Tensor& a_scales, // [M] or [1]
const std::optional<torch::Tensor>& azp, // [M] or [1]
const std::optional<torch::Tensor>& azp_adj, // [M] or [1]
const std::optional<torch::Tensor>& bias, // [N]
int64_t handler) {
CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(c.is_contiguous());
W8A8MatMulPrimitiveHandler* ptr =
reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
const int32_t* azp_ptr = nullptr;
if (azp.has_value()) {
azp_ptr = azp->data_ptr<int32_t>();
}
if (ptr->get_input_scale_strategy() ==
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
TORCH_CHECK_EQ(a_scales.numel(), 1);
}
W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
exec_args.a_ptr = a.data_ptr<int8_t>();
exec_args.a_m_size = a.size(0);
exec_args.bias_ptr = nullptr;
exec_args.bias_type = get_dnnl_type<void>();
exec_args.use_bias = false;
exec_args.a_scales_ptr = nullptr;
exec_args.a_zero_points_ptr = nullptr;
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
if (ptr->get_input_scale_strategy() ==
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
if (bias.has_value()) {
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
exec_args.bias_type = get_dnnl_type<scalar_t>();
exec_args.use_bias = true;
}
exec_args.a_scales_ptr = a_scales.data_ptr<float>();
exec_args.a_zero_points_ptr = azp_ptr;
exec_args.c_ptr = c.data_ptr<scalar_t>();
ptr->execute(exec_args);
} else if (ptr->get_input_scale_strategy() ==
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
torch::Tensor tmp_fp32_out =
torch::empty_like(c, ::at::ScalarType::Float);
exec_args.c_ptr = tmp_fp32_out.data_ptr<float>();
ptr->execute(exec_args);
if (bias.has_value()) {
if (azp.has_value()) {
dynamic_quant_epilogue<true, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
} else {
dynamic_quant_epilogue<false, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), azp_ptr, nullptr,
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
}
} else {
if (azp.has_value()) {
dynamic_quant_epilogue<true, false>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
(scalar_t*)nullptr, c.size(0), c.size(1));
} else {
dynamic_quant_epilogue<false, false>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), azp_ptr, nullptr, (scalar_t*)nullptr,
c.size(0), c.size(1));
}
}
} else {
TORCH_CHECK(false, "invalid act quant type.");
}
});
}
// static-per-tensor quantization.
void static_scaled_int8_quant(
torch::Tensor& out, // [batch, hidden_size]
const torch::Tensor& input, // [batch, hidden_size]
const torch::Tensor& scale, std::optional<torch::Tensor> const& azp) {
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK_EQ(input.dim(), 2);
TORCH_CHECK_EQ(input.stride(1), 1);
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
const int64_t stride = input.stride(0);
const int64_t hidden_size = input.size(1);
const int64_t num_tokens = input.size(0);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
if (azp.has_value()) {
static_scaled_int8_quant_impl<true>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
stride, hidden_size);
} else {
static_scaled_int8_quant_impl<false>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale.data_ptr<float>(), nullptr,
num_tokens, stride, hidden_size);
}
});
}
// dynamic-per-token quantization.
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [batch, hidden_size]
const torch::Tensor& input, // [batch, hidden_size]
torch::Tensor& scale, // [batch, 1]
std::optional<torch::Tensor> const& azp) {
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK_EQ(input.dim(), 2);
TORCH_CHECK_EQ(input.stride(1), 1);
const int64_t hidden_size = input.size(1);
const int64_t num_tokens = input.size(0);
const int64_t stride = input.stride(0);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
if (azp.has_value()) {
dynamic_scaled_int8_quant_impl<true>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
stride, hidden_size);
} else {
dynamic_scaled_int8_quant_impl<false>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), nullptr, num_tokens, stride,
hidden_size);
}
});
}
int64_t create_onednn_mm_handler(const torch::Tensor& b,
int64_t primitive_cache_size) {
TORCH_CHECK(b.dim() == 2);
MatMulPrimitiveHandler::Args args;
args.primitive_cache_size = primitive_cache_size;
args.b_k_size = b.size(0);
args.b_k_stride = b.stride(0);
args.b_n_size = b.size(1);
args.b_n_stride = b.stride(1);
args.b_ptr = b.data_ptr();
VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler",
[&] {
args.c_type = get_dnnl_type<scalar_t>();
args.ab_type = get_dnnl_type<scalar_t>();
});
return reinterpret_cast<int64_t>(new MatMulPrimitiveHandler(args));
}
void onednn_mm(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const std::optional<torch::Tensor>& bias, int64_t handler) {
CPU_KERNEL_GUARD_IN(onednn_mm)
TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.stride(-1) == 1);
TORCH_CHECK(c.stride(-1) == 1);
MatMulPrimitiveHandler* ptr =
reinterpret_cast<MatMulPrimitiveHandler*>(handler);
// ACL matmuls expect contiguous source tensors
#ifdef VLLM_USE_ACL
torch::Tensor a_contig = a.contiguous();
#endif
MatMulPrimitiveHandler::ExecArgs exec_args;
#ifdef VLLM_USE_ACL
exec_args.a_m_size = a_contig.size(0);
exec_args.a_m_stride = a_contig.stride(0);
#else
exec_args.a_m_size = a.size(0);
exec_args.a_m_stride = a.stride(0);
#endif
VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
if (bias.has_value()) {
exec_args.use_bias = true;
exec_args.bias_type = get_dnnl_type<scalar_t>();
#ifdef VLLM_USE_ACL
// ACL matmuls in oneDNN do not support a bias.
// We handle a matmul with bias by doing: c = bias; c += matmul(a, b)
c.copy_(bias.value());
#else
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
#endif
} else {
exec_args.use_bias = false;
exec_args.bias_type = get_dnnl_type<void>();
exec_args.bias_ptr = nullptr;
}
#ifdef VLLM_USE_ACL
exec_args.a_ptr = a_contig.data_ptr<scalar_t>();
#else
exec_args.a_ptr = a.data_ptr<scalar_t>();
#endif
exec_args.c_ptr = c.data_ptr<scalar_t>();
ptr->execute(exec_args);
});
}

106
csrc/cpu/float_convert.hpp Normal file
View File

@@ -0,0 +1,106 @@
static float bf16_to_float(uint16_t bf16) {
uint32_t bits = static_cast<uint32_t>(bf16) << 16;
float fp32;
std::memcpy(&fp32, &bits, sizeof(fp32));
return fp32;
}
static uint16_t float_to_bf16(float fp32) {
uint32_t bits;
std::memcpy(&bits, &fp32, sizeof(fp32));
return static_cast<uint16_t>(bits >> 16);
}
/************************************************
* Copyright (c) 2015 Princeton Vision Group
* Licensed under the MIT license.
* Codes below copied from
* https://github.com/PrincetonVision/marvin/tree/master/tools/tensorIO_matlab
*************************************************/
static uint16_t float_to_fp16(float fp32) {
uint16_t fp16;
unsigned x;
unsigned u, remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;
std::memcpy(&x, &fp32, sizeof(fp32));
u = (x & 0x7fffffff);
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
fp16 = 0x7fffU;
return fp16;
}
sign = ((x >> 16) & 0x8000);
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
fp16 = sign | 0x7c00U;
return fp16;
}
if (u < 0x33000001) {
fp16 = (sign | 0x0000);
return fp16;
}
exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);
if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
} else {
shift = 0x7e - exponent;
exponent = 0;
mantissa |= 0x800000;
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);
// Round to nearest even.
remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
++exponent;
mantissa = 0;
}
}
fp16 = (sign | (exponent << 10) | mantissa);
return fp16;
}
static float fp16_to_float(uint16_t fp16) {
unsigned sign = ((fp16 >> 15) & 1);
unsigned exponent = ((fp16 >> 10) & 0x1f);
unsigned mantissa = ((fp16 & 0x3ff) << 13);
int temp;
float fp32;
if (exponent == 0x1f) { /* NaN or Inf */
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
exponent = 0xff;
} else if (!exponent) { /* Denorm or Zero */
if (mantissa) {
unsigned int msb;
exponent = 0x71;
do {
msb = (mantissa & 0x400000);
mantissa <<= 1; /* normalize */
--exponent;
} while (!msb);
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
}
} else {
exponent += 0x70;
}
temp = ((sign << 31) | (exponent << 23) | mantissa);
std::memcpy(&fp32, &temp, sizeof(temp));
return fp32;
}

117
csrc/cpu/layernorm.cpp Normal file
View File

@@ -0,0 +1,117 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
vec_op::FP32Vec8 variance(0.0);
auto input_p = input + i * hidden_size;
auto output_p = out + i * hidden_size;
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
vec_op::FP32Vec8 fp32_x(x);
variance = variance + fp32_x * fp32_x;
}
float s_variance =
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
vec_op::FP32Vec8 fp32_s_variance(s_variance);
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
scalar_vec_t w(weight + j);
vec_op::FP32Vec8 fp32_x(x);
vec_op::FP32Vec8 fp32_w(w);
vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;
scalar_vec_t out(fp32_out);
out.save(output_p + j);
}
}
}
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t* __restrict__ residual,
const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens,
const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
vec_op::FP32Vec8 variance(0.0);
auto input_p = input + i * hidden_size;
auto residual_p = residual + i * hidden_size;
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
scalar_vec_t res(residual_p + j);
vec_op::FP32Vec8 fp32_x(x);
vec_op::FP32Vec8 fp32_res(res);
fp32_x = fp32_x + fp32_res;
variance = variance + fp32_x * fp32_x;
scalar_vec_t out(fp32_x);
out.save(residual_p + j);
}
float s_variance =
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
vec_op::FP32Vec8 fp32_s_variance(s_variance);
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t w(weight + j);
scalar_vec_t res(residual_p + j);
vec_op::FP32Vec8 fp32_w(w);
vec_op::FP32Vec8 fp32_res(res);
vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;
scalar_vec_t out(fp32_out);
out.save(input_p + j);
}
}
}
} // namespace
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl)
});
}
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "fused_add_rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
fused_add_rms_norm_impl(
input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
});
}

View File

@@ -0,0 +1,245 @@
#ifndef CPU_MICRO_GEMM_AMX_HPP
#define CPU_MICRO_GEMM_AMX_HPP
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
namespace cpu_micro_gemm {
namespace {
// AMX specific
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
typedef struct __tile_config {
uint8_t palette_id = 1;
uint8_t start_row = 0;
uint8_t reserved_0[14] = {0};
uint16_t colsb[16] = {0};
uint8_t rows[16] = {0};
} __tilecfg;
// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
// - 16
template <typename scalar_t>
class TileGemm224 {
public:
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
}
};
template <>
class TileGemm224<c10::BFloat16> {
public:
using scalar_t = c10::BFloat16;
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM;
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
// B is always packed as 16 output channels block
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_4 = c_ptr;
float* __restrict__ c_tile_5 =
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc;
float* __restrict__ c_tile_7 =
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
const int32_t c_tile_stride = ldc * sizeof(float);
if (accum_c) {
_tile_loadd(4, c_tile_4, c_tile_stride);
_tile_loadd(5, c_tile_5, c_tile_stride);
_tile_loadd(6, c_tile_6, c_tile_stride);
_tile_loadd(7, c_tile_7, c_tile_stride);
} else {
_tile_zero(4);
_tile_zero(5);
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
_tile_dpbf16ps(4, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
_tile_dpbf16ps(5, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_dpbf16ps(6, 1, 2);
_tile_dpbf16ps(7, 1, 3);
// update ptrs
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
_tile_stored(4, c_tile_4, c_tile_stride);
_tile_stored(5, c_tile_5, c_tile_stride);
_tile_stored(6, c_tile_6, c_tile_stride);
_tile_stored(7, c_tile_7, c_tile_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
const int32_t m_0 = AMX_TILE_ROW_NUM;
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
config.rows[0] = m_0;
config.rows[1] = m_1;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = m_0;
config.rows[5] = m_0;
config.rows[6] = m_1;
config.rows[7] = m_1;
_tile_loadconfig(&config);
}
};
// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
// m, m
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
// num should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be
// m
template <typename scalar_t>
class TileGemm122 {
public:
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
}
};
template <>
class TileGemm122<c10::BFloat16> {
public:
using scalar_t = c10::BFloat16;
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
c10::BFloat16* __restrict__ a_tile_1 =
a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
c10::BFloat16* __restrict__ b_tile_4 =
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_5 =
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
int64_t b_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_6 = c_ptr;
float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float);
int64_t c_stride = ldc * sizeof(float);
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
const int32_t k_group_times = k_times / 2;
const bool has_tail = (k_times % 2 == 1);
if (accum_c) {
_tile_loadd(6, c_tile_6, c_stride);
_tile_loadd(7, c_tile_7, c_stride);
} else {
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_group_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_stream_loadd(4, b_tile_4, b_stride);
_tile_dpbf16ps(6, 1, 4);
_tile_stream_loadd(5, b_tile_5, b_stride);
_tile_dpbf16ps(7, 1, 5);
// update ptrs
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
if (has_tail) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
}
_tile_stored(6, c_tile_6, c_stride);
_tile_stored(7, c_tile_7, c_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
config.rows[0] = m;
config.rows[1] = m;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = AMX_TILE_ROW_NUM;
config.rows[5] = AMX_TILE_ROW_NUM;
config.rows[6] = m;
config.rows[7] = m;
_tile_loadconfig(&config);
}
};
} // namespace
// Gemm kernel uses AMX, requires B matrix to be packed
template <typename scalar_t>
class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
public:
static constexpr int32_t MaxMSize = 32;
static constexpr int32_t NSize = 32;
public:
MicroGemm() : curr_m_(-1) {
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
}
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
if (m > AMX_TILE_ROW_NUM) {
if (m != curr_m_) {
curr_m_ = m;
TileGemm224<scalar_t>::init_tile_config(m, amx_tile_config_);
}
TileGemm224<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
} else {
if (m != curr_m_) {
curr_m_ = m;
TileGemm122<scalar_t>::init_tile_config(m, amx_tile_config_);
}
TileGemm122<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t curr_m_;
};
} // namespace cpu_micro_gemm
#endif

View File

@@ -0,0 +1,91 @@
#ifndef CPU_MICRO_GEMM_IMPL_HPP
#define CPU_MICRO_GEMM_IMPL_HPP
#include "cpu/utils.hpp"
#include "cpu/cpu_types.hpp"
namespace cpu_micro_gemm {
#define DEFINE_CPU_MICRO_GEMM_PARAMS \
scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \
float *__restrict__ c_ptr, const int32_t m, const int32_t k, \
const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \
const bool accum_c
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
template <cpu_utils::ISA isa, typename scalar_t>
class MicroGemm {
public:
static constexpr int32_t MaxMSize = 16;
static constexpr int32_t NSize = 16;
public:
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TORCH_CHECK(false, "Unimplemented MicroGemm.");
}
};
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr,
scalar_t* __restrict__ d_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
float* __restrict__ curr_c = c_ptr;
scalar_t* __restrict__ curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* __restrict__ curr_c_iter = curr_c;
scalar_t* __restrict__ curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_size / 16>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
scalar_vec_t c_vec(c_vec_fp32);
c_vec.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
scalar_t* __restrict__ d_ptr,
scalar_t* __restrict__ bias_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
constexpr int32_t n_group_num = n_size / 16;
static_assert(n_group_num <= 16);
vec_op::FP32Vec16 bias_vecs[n_group_num];
scalar_t* __restrict__ curr_bias = bias_ptr;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
scalar_vec_t vec(curr_bias);
bias_vecs[i] = vec_op::FP32Vec16(vec);
curr_bias += 16;
});
float* __restrict__ curr_c = c_ptr;
scalar_t* __restrict__ curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* __restrict__ curr_c_iter = curr_c;
scalar_t* __restrict__ curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
scalar_vec_t c_vec(c_vec_fp32);
c_vec.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
} // namespace cpu_micro_gemm
#endif

View File

@@ -0,0 +1,115 @@
#ifndef CPU_MICRO_GEMM_VEC_HPP
#define CPU_MICRO_GEMM_VEC_HPP
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
namespace cpu_micro_gemm {
namespace {
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
template <typename scalar_t>
class TileGemm82 {
public:
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
switch (m) {
case 1:
gemm_micro<1>(CPU_MICRO_GEMM_PARAMS);
break;
case 2:
gemm_micro<2>(CPU_MICRO_GEMM_PARAMS);
break;
case 3:
gemm_micro<3>(CPU_MICRO_GEMM_PARAMS);
break;
case 4:
gemm_micro<4>(CPU_MICRO_GEMM_PARAMS);
break;
case 5:
gemm_micro<5>(CPU_MICRO_GEMM_PARAMS);
break;
case 6:
gemm_micro<6>(CPU_MICRO_GEMM_PARAMS);
break;
case 7:
gemm_micro<7>(CPU_MICRO_GEMM_PARAMS);
break;
case 8:
gemm_micro<8>(CPU_MICRO_GEMM_PARAMS);
break;
}
}
template <int32_t M>
static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
static_assert(0 < M <= 8);
using load_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
scalar_t* __restrict__ curr_b_0 = b_ptr;
scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride;
float* __restrict__ curr_c_0 = c_ptr;
float* __restrict__ curr_c_1 = c_ptr + 16;
vec_op::FP32Vec16 c_regs[M * 2];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
float* __restrict__ curr_m_c_1 = curr_c_1;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
// update
curr_m_c_0 += ldc;
curr_m_c_1 += ldc;
});
}
scalar_t* __restrict__ curr_a = a_ptr;
for (int32_t k_idx = 0; k_idx < k; ++k_idx) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
load_vec_t b_1_reg(curr_b_1);
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
scalar_t* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
scalar_t v = *curr_m_a;
load_vec_t a_reg_original(v);
vec_op::FP32Vec16 a_reg(a_reg_original);
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += 16;
curr_b_1 += 16;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2].save(curr_c_0);
c_regs[i * 2 + 1].save(curr_c_1);
// update
curr_c_0 += ldc;
curr_c_1 += ldc;
});
}
};
} // namespace
// Gemm kernel uses vector instructions, requires B matrix to be packed
template <typename scalar_t>
class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
public:
static constexpr int32_t MaxMSize = 8;
static constexpr int32_t NSize = 32;
public:
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
};
} // namespace cpu_micro_gemm
#endif

393
csrc/cpu/mla_decode.cpp Normal file
View File

@@ -0,0 +1,393 @@
#include "cpu_types.hpp"
#include <float.h>
namespace {
template <typename scalar_t>
struct KernelVecType {
using qk_load_vec_type = void;
using qk_vec_type = void;
using v_load_vec_type = void;
};
template <>
struct KernelVecType<float> {
using qk_load_vec_type = vec_op::FP32Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
// Power and s390x architecture-specific vector types
using qk_load_vec_type = vec_op::FP32Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
#else
// Fallback for other architectures, including x86
using qk_load_vec_type = vec_op::FP16Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP16Vec16;
#endif
};
#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec32;
using qk_vec_type = vec_op::BF16Vec32;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// pass
#else
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
typename qk_vec_type>
void mla_decode_block_head(
const qk_vec_type* __restrict__ q_vecs, // [HEAD_UNROLL, head_dim]
const qk_vec_type* __restrict__ k_vecs, // [block_size, head_dim]
const vec_op::FP32Vec16* __restrict v_vecs_f32, // [block_size, v_head_dim]
float* __restrict__ acc_out, // [HEAD_UNROLL, v_head_dim]
float* __restrict__ acc_lse, // [HEAD_UNROLL]
const float scale, const int num_tokens) {
using f32_vec_type = vec_op::FP32Vec16;
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;
float logits[BLOCK_SIZE][HEAD_UNROLL] = {}; // initialize to zeros
float max_val[HEAD_UNROLL];
std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);
f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
// load to registers
qk_vec_type q_vec[HEAD_UNROLL];
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
q_vec[unroll] =
qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
}
}
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
logits[block_offset][unroll] = acc;
max_val[unroll] = std::max(max_val[unroll], acc);
}
}
float sum_exp[HEAD_UNROLL] = {};
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
const float val =
std::exp(logits[block_offset][unroll] - max_val[unroll]);
logits[block_offset][unroll] = val;
sum_exp[unroll] += val;
}
}
f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
// load to registers
f32_vec_type scale_[HEAD_UNROLL];
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
scale_[unroll] =
f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
f32_vec_type v_vec(
v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
}
}
// merge attention state
// section 2.2 in https://arxiv.org/pdf/2501.01005
f32_vec_type prev_scale[HEAD_UNROLL];
f32_vec_type curr_scale[HEAD_UNROLL];
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
const float prev_lse = acc_lse[unroll];
const float curr_lse = std::log(sum_exp[unroll]) +
max_val[unroll]; // add back max_val to get true lse
// softmax trick
const float max_lse = std::max(prev_lse, curr_lse);
const float prev_sum_exp = std::exp(prev_lse - max_lse);
const float curr_sum_exp = std::exp(curr_lse - max_lse);
const float new_sum_exp = prev_sum_exp + curr_sum_exp;
acc_lse[unroll] = std::log(new_sum_exp) + max_lse;
prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
}
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
o_vec = o_vec * prev_scale[unroll] +
this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
}
}
q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
acc_out += V_HEAD_DIM * HEAD_UNROLL;
}
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
typename qk_vec_type>
void mla_decode_block(
const qk_vec_type* __restrict__ q_vecs, // [num_heads, head_dim]
const scalar_t* __restrict__ kv_cache, // [block_size, head_dim]
float* __restrict__ acc_out, // [num_heads, v_head_dim]
float* __restrict__ acc_lse, // [num_heads]
const int num_heads, const float scale, const int num_tokens) {
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
static_assert(
std::is_same<qk_vec_type,
typename KernelVecType<scalar_t>::qk_vec_type>::value);
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
using f32_vec_type = vec_op::FP32Vec16;
static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;
const qk_vec_type* k_vecs;
const f32_vec_type* v_vecs_f32;
float* kv_cache_f32 = nullptr;
if constexpr (!std::is_same<scalar_t, float>::value) {
// convert KV cache block to FP32 to reuse it across query heads and
// attn @ V computation, since FP16/BF16->FP32 is expensive.
// TODO: move malloc outside of this fn to reuse across iterations.
const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));
for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
f32_vec_type kv_vec_f32(kv_load_vec);
kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
}
if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
// for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
// NOTE: in this case, we only need to convert the V section to FP32.
// But for simplicity, we will convert the whole KV block to FP32.
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
} else {
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
}
// attn @ V always use FP32 for V, since attn is FP32.
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);
} else {
// KV cache is FP32. don't need to do anything.
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
}
// compute 2 heads at the same time to improve ILP and
// take advantage of register cache for K and V.
constexpr int HEAD_UNROLL = 2;
for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
acc_out += HEAD_UNROLL * V_HEAD_DIM;
acc_lse += HEAD_UNROLL;
}
// take care of the remaining heads
for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
q_vecs += HEAD_DIM / QK_NUM_ELEM;
acc_out += V_HEAD_DIM;
acc_lse += 1;
}
if (kv_cache_f32 != nullptr) {
std::free(kv_cache_f32);
}
}
} // namespace
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
void mla_decode_kvcache_cpu_impl(
scalar_t* __restrict__ out, // [num_seqs, num_heads, v_head_dim]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_dim]
const scalar_t* __restrict__ kv_cache, // [num_blocks, block_size,
// head_dim]
const int num_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
const int kv_stride, const int num_seqs) {
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
// shared across threads
const int max_threads = omp_get_max_threads();
const int acc_out_nbytes =
max_threads * num_heads * V_HEAD_DIM * sizeof(float);
float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
std::vector<float> acc_lse(max_threads * num_heads);
// allocate memory to pre-convert query to FP32 later
float* q_f32;
constexpr bool PRE_CONVERT_QUERY =
!std::is_same<scalar_t, float>::value &&
std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
if constexpr (PRE_CONVERT_QUERY) {
const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
}
#pragma omp parallel
{
const int num_threads = omp_get_num_threads();
const int thread_id = omp_get_thread_num();
float* __restrict__ acc_out_thread =
acc_out + thread_id * num_heads * V_HEAD_DIM;
float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
// reset accumulator
std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);
const int seq_len = seq_lens[seq_idx];
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;
const qk_vec_type* q_vecs;
if constexpr (PRE_CONVERT_QUERY) {
// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
#pragma omp for
for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
qk_vec_type q_vec(q_load_vec);
q_vec.save(q_f32 + i);
}
q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
} else {
q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
}
#pragma omp for
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int physical_block_idx =
block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
const int num_tokens =
block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;
mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
acc_lse_thread, num_heads, scale, num_tokens);
}
// merge attention states across threads
// section 2.2 in https://arxiv.org/pdf/2501.01005
// each thread is responsible for 1 head
#pragma omp for
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
float* acc_lse_head = acc_lse.data() + head_idx;
float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;
float max_val = -FLT_MAX;
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
}
float sum_exp = 0.0f;
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
acc_lse_head[thread_id_ * num_heads] = val;
sum_exp += val;
}
float inv_sum = 1.0f / sum_exp;
float out_head[V_HEAD_DIM] = {};
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
for (int i = 0; i < V_HEAD_DIM; ++i) {
out_head[i] +=
acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
}
}
for (int i = 0; i < V_HEAD_DIM; ++i) {
vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
head_idx * V_HEAD_DIM + i);
}
}
}
}
if (PRE_CONVERT_QUERY) {
std::free(q_f32);
}
std::free(acc_out);
}
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens) {
const int num_seqs = query.size(0);
const int num_heads = query.size(1);
const int head_dim = query.size(2);
const int block_size = kv_cache.size(1);
const int v_head_dim = out.size(2);
const int max_num_blocks_per_seq = block_tables.size(1);
const int o_stride = out.stride(0);
const int q_stride = query.stride(0);
const int kv_stride = kv_cache.stride(0);
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
kv_cache.data_ptr<scalar_t>(), num_heads, scale,
block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
else
TORCH_CHECK(false, "Unsupported block size: ", block_size);
CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
});
}

208
csrc/cpu/pos_encoding.cpp Normal file
View File

@@ -0,0 +1,208 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void rotary_embedding_impl(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
const int embed_dim = rot_dim / 2;
bool flag = (embed_dim % VEC_ELEM_NUM == 0);
const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
scalar_t* qk) {
int j = 0;
for (; j < loop_upper; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(qk + out_x);
const scalar_vec_t q_y(qk + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(qk + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(qk + out_y);
}
if (!flag) {
for (; j < embed_dim; ++j) {
const int x_index = j;
const int y_index = embed_dim + j;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const float fp32_cos = cache_ptr[x_index];
const float fp32_sin = cache_ptr[y_index];
const float fp32_q_x = qk[out_x];
const float fp32_q_y = qk[out_y];
qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
}
}
};
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, query);
}
if (key != nullptr) {
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
}
}
}
template <typename scalar_t>
void rotary_embedding_gptj_impl(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
const int embed_dim = rot_dim / 2;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
scalar_t* head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
const int y_index = 2 * rot_offset + 1;
const float cos = cos_cache_ptr[rot_offset];
const float sin = sin_cache_ptr[rot_offset];
const float x = head_query[x_index];
const float y = head_query[y_index];
head_query[x_index] = x * cos - y * sin;
head_query[y_index] = y * cos + x * sin;
}
}
}
if (key == nullptr) {
return;
}
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t* head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
const int y_index = 2 * rot_offset + 1;
const float cos = cos_cache_ptr[rot_offset];
const float sin = sin_cache_ptr[rot_offset];
const float x = head_key[x_index];
const float y = head_key[y_index];
head_key[x_index] = x * cos - y * sin;
head_key[y_index] = y * cos + x * sin;
}
}
}
}
}; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "rotary_embedding_impl", [&] {
CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
if (is_neox) {
rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
} else {
rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
}
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
});
}

View File

@@ -0,0 +1,23 @@
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}

View File

@@ -0,0 +1,31 @@
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif

View File

@@ -0,0 +1,238 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>
// clang-format off
#if defined(_OPENMP)
#include <omp.h>
#endif
namespace {
// dispatch bool
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
[&] { \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool BOOL_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::BFloat16 : { \
using packed_t = at::BFloat16; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using packed_t = at::Half; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Char : { \
using packed_t = int8_t; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Float8_e4m3fn : { \
using packed_t = at::Float8_e4m3fn; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
}()
#define UNUSED(x) (void)(x)
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CPU(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
// parallel routines
constexpr int GRAIN_SIZE = 1024;
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) { return (x + y - 1) / y; }
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
// onednn partition pattern
T& n_my = n_end;
if (nth <= 1 || n == 0) {
n_start = 0;
n_my = n;
} else {
T n1 = div_up(n, nth);
T n2 = n1 - 1;
T T1 = n - n2 * nth;
n_my = ith < T1 ? n1 : n2;
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
}
n_end += n_start;
#else
// pytorch aten partition pattern
T n_my = div_up(n, nth);
n_start = ith * n_my;
n_end = std::min(n_start + n_my, n);
#endif
}
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#else
f(0, n);
#endif
}
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) {
int actual_nth = at::get_num_threads();
if (m == 1) {
return actual_nth;
}
return std::max(1, (actual_nth >> 1) * 2);
}
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
// make sure we have even num_threads
int nth = adjust_num_threads(m);
// [NOTE] thread blocking:
//
// 1) prefer square block per thread
// 2) use even number of CPU cores
// 3) use all `num_threads` cores
//
// we have:
// TM * TN = T
// BM / TM = BN / TN
// then:
// TM = ((BM / BN) * T) ^ 0.5
//
float r = float(m) / n;
int nth_m = std::ceil(std::sqrt(r * nth));
int nth_n = 1;
for (; nth_m > 0; --nth_m) {
nth_n = nth / nth_m;
if (nth_m * nth_n == nth) {
break;
}
}
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
{
int ith = omp_get_thread_num();
int ith_m = ith / nth_n;
int ith_n = ith % nth_n;
int thread_block_m = div_up(m, nth_m);
int thread_block_n = div_up(n, nth_n);
int begin_m = ith_m * thread_block_m;
int end_m = std::min(m, begin_m + thread_block_m);
int begin_n = ith_n * thread_block_n;
int end_n = std::min(n, begin_n + thread_block_n);
f(begin_m, end_m, begin_n, end_n);
}
#else
f(0, m, 0, n);
#endif
}
template <typename T>
int get_cache_blocks(int BLOCK_SIZE, int K) {
// L2 2MB and ratio of 50%
const int L2_size = 2048 * 1024 >> 1;
return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T))));
}
// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
return offset;
}
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
offset = data_index_init(offset, std::forward<Args>(args)...);
x = offset % X;
return offset / X;
}
inline bool data_index_step() {
return true;
}
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
if (data_index_step(std::forward<Args>(args)...)) {
x = ((x + 1) == X) ? 0 : (x + 1);
return x == 0;
}
return false;
}
// forced unroll for perf critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif
template <int n>
struct Unroll {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
Unroll<n - 1>{}(f, args...);
f(std::integral_constant<int, n - 1>{}, args...);
}
};
template <>
struct Unroll<1> {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
f(std::integral_constant<int, 0>{}, args...);
}
};
} // anonymous namespace

View File

@@ -0,0 +1,464 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#include "common.h"
#include "vec.h"
#include "gemm.h"
// clang-format off
namespace {
// packed layout:
// quants {N, K} int8_t
// comp {N} int32_t
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
__m512i vcomp[COLS];
for (int col = 0; col < COLS; ++col) {
vcomp[col] = _mm512_setzero_si512();
}
const int64_t offset = BLOCK_N * K;
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
for (int k = 0; k < K / 4; ++k) {
for (int col = 0; col < COLS; ++col) {
__m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64));
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
}
}
for (int col = 0; col < COLS; ++col) {
_mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]);
}
#else
TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
const int VNNI_BLK = 2;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
}
template <>
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
constexpr int BLOCK_N = block_size_n();
TORCH_CHECK(N == BLOCK_N);
const int VNNI_BLK = 4;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
s8s8_compensation<BLOCK_N>(packed, K);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C,
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_set1_ps(0.f);
}
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const float* b_ptr = reinterpret_cast<const float*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
// for COLS = 1, 3 use 256bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(C + row * ldc + col * 16),
(__m256i)(_mm512_cvtneps_pbh(vc[i])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
template <typename scalar_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
float* __restrict__ Ctmp, const float* __restrict__ bias,
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(
M, N, K, lda, ldb, BLOCK_N, /* add_C */false,
A, B, Ctmp);
// copy from Ctmp to C
for (int64_t m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
if (brg) {
brgemm<scalar_t, has_bias>::apply(
A, B, C, Ctmp, bias,
M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch(mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
// mb_size = 2
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
// mb_size = 3
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
// mb_size = 4
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void weight_packed_linear_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
// l2 cache block for n
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
/* C */ out + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
}}}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \
int64_t ldb, int64_t ldc, bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor convert_weight_packed(at::Tensor& weight) {
// for 3d moe weights
// weight : [E, OC, IC]
// w1 : [E, 2N, K]
// w2 : [E, K, N]
CHECK_INPUT(weight);
const int64_t ndim = weight.ndimension();
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
const auto st = weight.scalar_type();
const int64_t E = ndim == 3 ? weight.size(0) : 1;
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
// we handle 2 TILE_N at a time.
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
constexpr int64_t BLOCK_N = block_size_n();
const int64_t NB = div_up(OC, BLOCK_N);
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
auto packed_weight = at::empty({}, weight.options());
const int64_t stride = OC * IC;
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
CPU_DISPATCH_PACKED_TYPES(st, [&] {
// adjust most inner dimension size
const int packed_row_size = get_row_size<packed_t>(IC);
auto sizes = weight.sizes().vec();
sizes[ndim - 1] = packed_row_size;
packed_weight.resize_(sizes);
const packed_t* w_data = weight.data_ptr<packed_t>();
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
// parallel on {E, NB}
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
int64_t e{0}, nb{0};
data_index_init(begin, e, E, nb, NB);
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int64_t n = nb * BLOCK_N;
int64_t n_size = std::min(BLOCK_N, OC - n);
pack_vnni<packed_t>(
packed_data + e * OC * packed_row_size + n * packed_row_size,
w_data + e * stride + n * IC,
n_size,
IC);
// move to the next index
data_index_step(e, E, nb, NB);
}
});
});
return packed_weight;
}
// mat1 : [M, K]
// mat2 : [N, K]
// bias : [N]
// out : [M, N]
//
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias, bool is_vnni) {
RECORD_FUNCTION(
"sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
auto out = at::empty({M, N}, mat1.options());
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
weight_packed_linear_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
bias_data,
M,
N,
K,
mat1_strideM,
out_strideM);
});
return out;
}

266
csrc/cpu/sgl-kernels/gemm.h Normal file
View File

@@ -0,0 +1,266 @@
#pragma once
#include <ATen/native/CPUBlas.h>
// clang-format off
// amx-bf16
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// block size for AMX gemm
constexpr int block_size_m() { return 2 * TILE_M; }
constexpr int block_size_n() { return 2 * TILE_N; }
// define threshold using brgemm (intel AMX)
template <typename T> inline bool can_use_brgemm(int M);
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
// adjust leading dimension size for K
template <typename T>
inline int64_t get_row_size(int64_t K) {
return K;
}
template <>
inline int64_t get_row_size<int8_t>(int64_t K) {
return K + sizeof(int32_t);
}
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// moe implementations for fp8 w8a16
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// moe implementations for int4 w4a16
template <typename scalar_t>
void fused_experts_int4_w4a16_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::quint4x2* __restrict__ packed_w1,
const at::quint4x2* __restrict__ packed_w2,
const uint8_t* __restrict__ w1z,
const uint8_t* __restrict__ w2z,
const scalar_t* __restrict__ w1s,
const scalar_t* __restrict__ w2s,
int group_size,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// shared expert implementation for int8 w8a8
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K);
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::quint4x2* __restrict__ B,
scalar_t* __restrict__ C,
const uint8_t* __restrict__ Bz,
const scalar_t* __restrict__ Bs,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int group_size,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t strideBz,
int64_t strideBs,
bool brg);
// TODO: debug print, remove me later
inline void print_16x32i(const __m512i x) {
int32_t a[16];
_mm512_storeu_si512((__m512i *)a, x);
for (int i = 0; i < 16; i++){
std::cout << a[i] << " ";
}
std::cout << std::endl;
}
inline void print_16x32(const __m512 x) {
float a[16];
_mm512_storeu_ps((__m512 *)a, x);
for (int i = 0; i < 16; i++){
std::cout << a[i] << " ";
}
std::cout << std::endl;
}
inline void print_32x8u(const __m256i x) {
uint8_t a[32];
_mm256_storeu_si256((__m256i *)a, x);
for (int i = 0; i < 32; ++i) {
std::cout << int32_t(a[i]) << " ";
}
std::cout << std::endl;
}

View File

@@ -0,0 +1,530 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#include "common.h"
#include "vec.h"
#include "gemm.h"
// clang-format off
// we use 4x32 for BLOCK_M
#define BLOCK_SIZE_M_SCALE 4
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
inline void unpack_B(
at::BFloat16* __restrict__ Btmp,
const at::Float8_e4m3fn* __restrict__ packed_B,
int N,
int K,
int ldb,
int ldb_tmp,
float scale) {
#if defined(CPU_CAPABILITY_AVX512)
// [K/2, N, 2]
const int K2 = K >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
const __m512 vd = _mm512_set1_ps(scale);
constexpr int BLOCK_N = block_size_n();
static_assert(BLOCK_N == 32);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 64;
#pragma GCC unroll 4
for (int k = 0; k < K2; ++k) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
}
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
}
#else
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
#endif
}
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C,
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C,
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
const int KB = div_up(K, BLOCK_K);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 64;
constexpr int PREFETCH_SIZE_KB = 1;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vsum[ROWS * COLS];
// block quant scale
__m512 vscale;
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_setzero_ps();
}
};
Unroll<ROWS * COLS>{}(loadc);
const int lda2 = lda >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
auto compute = [&](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
}
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
}
}
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
};
constexpr int BLOCK_K2 = BLOCK_K >> 1;
for (int kb = 0; kb < KB; ++kb) {
int kb_start = kb * BLOCK_K2;
int kb_end = std::min(K, kb_start + BLOCK_K2);
// 1. load scale vector
vscale = _mm512_set1_ps(scale[kb]);
if constexpr (PREFETCH_SIZE_KB > 0) {
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
}
// 2. zero vsum for each block
Unroll<ROWS * COLS>{}([&](auto i) {
vsum[i] = _mm512_setzero_ps();
});
// 3. accumulate across each block
for (int k = kb_start; k < kb_end; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
// 4. apply scale
Unroll<ROWS * COLS>{}([&](auto i) {
vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]);
});
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2,4 use 512bit store
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K);
template <typename scalar_t, typename packed_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
}
};
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
at::BFloat16* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc) {
constexpr int BLOCK_N = block_size_n();
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
const int ldb_tmp = BLOCK_N;
for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
}
at::native::cpublas::brgemm(
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K) {
if (brg) {
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch(mb_size << 4 | nb_size >> 4) {
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void fp8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const at::Float8_e4m3fn* __restrict__ mat2,
const float* __restrict__ scales2,
const float* __restrict__ bias,
scalar_t* __restrict__ buffer,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM,
int64_t block_size_N,
int64_t block_size_K,
int64_t buffer_size_per_thread) {
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
const int64_t scale_size_K = div_up(K, block_size_K);
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
int tid = at::get_thread_num();
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
/* C */ out + mb_start * out_strideM + nb_start,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* scale */ scale_ptr,
/* bias */ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K) {
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, \
const at::Float8_e4m3fn* __restrict__ B, \
TYPE* __restrict__ C, \
TYPE* __restrict__ Btmp, \
float* __restrict__ Ctmp, \
const float* __restrict__ scale, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg, \
int64_t block_size_K)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
std::vector<int64_t> block_size, std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
"fp8_scaled_mm_cpu: expect scales2 to be float32.");
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
TORCH_CHECK(block_size.size() == 2,
"fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1];
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
"fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype,
"fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn,
"fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
"fp8_scaled_mm_cpu: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
// Btmp : [T, BLOCK_N * K]
// Ctmp : [T, BLOCK_M * BLOCK_N]
int num_threads = at::get_num_threads();
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<at::Float8_e4m3fn>(),
scales2.data_ptr<float>(),
bias_data,
buffer.data_ptr<scalar_t>(),
M,
N,
K,
mat1_strideM,
out_strideM,
block_size_N,
block_size_K,
size_per_thread);
});
return out;
}

View File

@@ -0,0 +1,440 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#include "common.h"
#include "vec.h"
#include "gemm.h"
// clang-format off
namespace {
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C,
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512i va;
__m512i vb[COLS];
__m512i vc[ROWS * COLS];
__m512i vcomp[COLS];
__m512 vd0;
__m512 vd1[COLS];
// oops! 4x4 spills but luckily we use 4x2
__m512 vbias[COLS];
// [NOTE]: s8s8 igemm compensation in avx512-vnni
//
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
//
// a * b = (a + 128) * b - 128 * b
// s s u s u s
//
// 1) 128 * b is pre-computed when packing B to vnni formats
// 2) a + 128 is fused when dynamically quantize A
//
auto loadc = [&](auto i) {
vc[i] = _mm512_set1_epi32(0);
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr(col == 0) {
vd0 = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp per 2 vectors
// also load bias if any
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
if constexpr (has_bias) {
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
}
}
}
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
if constexpr (has_bias) {
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
} else {
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
}
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
// B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch(mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
// mb_size = 2
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
// mb_size = 3
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
// mb_size = 4
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template<typename scalar_t>
void int8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const uint8_t* __restrict__ mat1,
const int8_t* __restrict__ mat2,
const float* __restrict__ scales1,
const float* __restrict__ scales2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
const bool use_brgemm = false;
// K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K);
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use int32_t for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * K,
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
/* C */ out + mb_start * N + nb_start,
/* Ctmp*/ Ctmp,
/* As */ scales1 + mb_start,
/* Bs */ scales2 + nb_start,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ N,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs,
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
CHECK_DIM(2, A);
int64_t M = A.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
const auto st = A.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
"per_token_quant_int8: expect A to be bfloat16 or half.");
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
auto As = at::empty({M}, A.options().dtype(at::kFloat));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
float* __restrict__ As_data = As.data_ptr<float>();
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(
Aq_data + m * K,
As_data[m],
A_data + m * lda,
K);
}
});
});
return std::make_tuple(Aq, As);
}
// weight : static, per-channel, symmetric
// activation : dynamic, per-token, symmetric
//
// mat1 : [M, K]
// mat2 : [N, K]
// scales1 : [M]
// scales2 : [N]
// bias : [N]
// out : [M, N]
//
at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
at::Tensor& scales1, at::Tensor& scales2,
std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales1);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales1.numel(), M);
CHECK_EQ(scales2.numel(), N);
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
"int8_scaled_mm: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<uint8_t>(),
packed_w.data_ptr<int8_t>(),
scales1.data_ptr<float>(),
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
int64_t lda = mat1.stride(0);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales2.numel(), N);
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
"int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype,
"int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kChar,
"int8_scaled_mm_with_quant: expect mat2 to be int8.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
"int8_scaled_mm_with_quant: expect scales to be float32.");
const int64_t buffer_size = M * K + M * sizeof(float);
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(
Aq_data + m * K,
As_data[m],
A_data + m * lda,
K);
}
});
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
Aq_data,
packed_w.data_ptr<int8_t>(),
As_data,
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}

1330
csrc/cpu/sgl-kernels/moe.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,502 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#include "common.h"
#include "gemm.h"
#include "vec.h"
// clang-format off
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
Vec data = Vec::loadu(input + d);
data.store(out + d);
}
}
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec weight_vec = fVec(weight);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
bVec x = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x);
x0 = x0 * weight_vec;
x1 = x1 * weight_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] * weight);
}
}
// acc from [topk, K] to [K]
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
if (topk == 1) {
// do copy for topk = 1
copy_stub(out, input, K);
} else {
// do sum for topk != 1
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= K - kVecSize; d += kVecSize) {
fVec sum_fvec0 = fVec(0.f);
fVec sum_fvec1 = fVec(0.f);
for (int t = 0; t < topk; ++t) {
bVec x_bvec = bVec::loadu(input + t * K + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec0 += x_fvec0;
sum_fvec1 += x_fvec1;
}
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
out_bvec.store(out + d);
}
for (; d < K; ++d) {
float sum_val = 0.f;
for (int t = 0; t < topk; ++t) {
sum_val += static_cast<float>(input[t * K + d]);
}
out[d] = static_cast<scalar_t>(sum_val);
}
}
}
// out = input + input2 * scale
template <typename scalar_t>
inline void add_mul_stub(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ input2,
float scale,
int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_vec = fVec(scale);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
bVec y_bvec = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
x0 = x0 + y0 * s_vec;
x1 = x1 + y1 * s_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
}
}
template <typename scalar_t>
inline void silu_and_mul_stub(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ input2,
int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += bVec::size()) {
bVec x = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x);
bVec y = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y);
x0 = x0 / (one + x0.neg().exp_u20());
x1 = x1 / (one + x1.neg().exp_u20());
x0 = x0 * y0;
x1 = x1 * y1;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
}
} // anonymous namespace
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 1: intermediate_cache0 = hidden_states @ w1
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
const int64_t NB = div_up(2 * N, BLOCK_N);
int64_t scale_size_N = div_up(2 * N, block_size_N);
int64_t scale_size_K = div_up(K, block_size_K);
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K;
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K);
}
const int64_t offset = offsets[mb];
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ 2 * N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
}
if (is_brgemm_used) {
at::native::cpublas::brgemm_release();
}
});
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
silu_and_mul_stub(
ic1 + m * N,
ic0 + m * 2 * N,
ic0 + m * 2 * N + N,
N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [E, K, N] as [E, OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
scale_size_N = div_up(K, block_size_N);
scale_size_K = div_up(N, block_size_K);
const int64_t stride_e2 = OC * IC;
const int64_t stride_oc = IC;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m];
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
}
if (is_brgemm_used) {
at::native::cpublas::brgemm_release();
}
});
// stage 3: out = intermediate_cache2.sum(dim=1)
// from [M, topk, K] to [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
}
});
}
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \
template void fused_experts_fp8_kernel_impl<TYPE>( \
TYPE* __restrict__ output, \
TYPE* __restrict__ ic0, \
TYPE* __restrict__ ic1, \
TYPE* __restrict__ ic2, \
TYPE* __restrict__ A_tmp, \
TYPE* __restrict__ B_tmp, \
float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \
const float* __restrict__ w1s, \
const float* __restrict__ w2s, \
int64_t block_size_N, \
int64_t block_size_K, \
const float* __restrict__ topk_weights, \
const int32_t* __restrict__ sorted_ids, \
const int32_t* __restrict__ expert_ids, \
const int32_t* __restrict__ offsets, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t E, \
int64_t topk, \
int64_t num_tokens_post_pad)
INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 1: intermediate_cache0 = hidden_states @ w1
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(2 * N, BLOCK_N);
int64_t scale_size_K = div_up(K, block_size_K);
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
tinygemm_kernel<scalar_t>(
/* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ 2 * N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
silu_and_mul_stub(
ic1 + m * N,
ic0 + m * 2 * N,
ic0 + m * 2 * N + N,
N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [K, N] as [OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(K, BLOCK_N);
scale_size_K = div_up(N, block_size_K);
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
}
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
template void shared_expert_fp8_kernel_impl<TYPE>( \
TYPE* __restrict__ output, \
TYPE* __restrict__ ic0, \
TYPE* __restrict__ ic1, \
TYPE* __restrict__ B_tmp, \
float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \
const float* __restrict__ w1s, \
const float* __restrict__ w2s, \
int64_t block_size_N, \
int64_t block_size_K, \
const TYPE* __restrict__ fused_experts_out, \
float routed_scaling_factor, \
int64_t M, \
int64_t N, \
int64_t K)
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);

View File

@@ -0,0 +1,769 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#include "common.h"
#include "vec.h"
#include "gemm.h"
// clang-format off
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
Vec data = Vec::loadu(input + d);
data.store(out + d);
}
}
template <>
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
// size might be 64x + 32
std::memcpy(out, input, size * sizeof(uint8_t));
}
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec weight_vec = fVec(weight);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) * weight_vec;
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] * weight);
}
}
// acc from [topk, K] to [K]
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
if (topk == 1) {
// do copy for topk = 1
copy_stub(out, input, K);
} else {
// do sum for topk != 1
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= K - kVecSize; d += kVecSize) {
fVec sum_fvec0 = fVec(0.f);
fVec sum_fvec1 = fVec(0.f);
for (int t = 0; t < topk; ++t) {
bVec x_bvec = bVec::loadu(input + t * K + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec0 += x_fvec0;
sum_fvec1 += x_fvec1;
}
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
out_bvec.store(out + d);
}
for (; d < K; ++d) {
float sum_val = 0.f;
for (int t = 0; t < topk; ++t) {
sum_val += static_cast<float>(input[t * K + d]);
}
out[d] = static_cast<scalar_t>(sum_val);
}
}
}
// out = input + input2 * scale
template <typename scalar_t>
inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input,
const scalar_t* __restrict__ input2, float scale, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_vec = fVec(scale);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec x0 = fVec::loadu(input + d);
fVec x1 = fVec::loadu(input + d + fVec::size());
bVec y_bvec = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
x0 = x0 + y0 * s_vec;
x1 = x1 + y1 * s_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
}
}
/// gemm for w13
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni {
static inline void apply(
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C,
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C,
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512i va;
__m512i vb0[COLS];
__m512i vb1[COLS];
__m512i vc0[ROWS * COLS];
__m512i vc1[ROWS * COLS];
__m512i vcomp0[COLS];
__m512i vcomp1[COLS];
__m512 was;
__m512 vbs0[COLS];
__m512 vbs1[COLS];
auto loadc = [&](auto i) {
vc0[i] = _mm512_set1_epi32(0);
vc1[i] = _mm512_set1_epi32(0);
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
}
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto scalec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr(col == 0) {
was = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp
if constexpr (row == 0) {
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
}
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col]));
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col]));
};
Unroll<ROWS * COLS>{}(scalec);
using Vec = at::vec::Vectorized<float>;
const Vec one = Vec(1.f);
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
// silu
x0 = x0 / (one + x0.neg().exp_u20());
x1 = x1 / (one + x1.neg().exp_u20());
// mul
x0 = x0 * y0;
x1 = x1 * y1;
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \
C + mb_start * ldc + nb_start, As + mb_start, \
Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\
K, lda, ldb, ldc);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B0,
const int8_t* __restrict__ B1,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
// pattern: 1-(2+2)-(8+8)
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 32;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch(mb_size << 4 | nb_size >> 4) {
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break;
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break;
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break;
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break;
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
/// gemm for w2
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni2 {
static inline void apply(
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512i va;
__m512i vb[COLS];
__m512i vc[ROWS * COLS];
__m512i vcomp[COLS];
__m512 was;
__m512 vbs[COLS];
auto loadc = [&](auto i) {
vc[i] = _mm512_set1_epi32(0);
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
}
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr(col == 0) {
was = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp per 2 vectors
// also load bias if any
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
}
}
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]);
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
K, lda, ldb, ldc);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
float* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
// B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch(mb_size << 4 | nb_size >> 4) {
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break;
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break;
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break;
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break;
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
} // anonymous namespace
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad) {
// handle 2 tiles per block
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 0: quantize input to uint8, [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(
Aq_tmp + m * K,
As_tmp[m],
input + m * K,
K);
}
});
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// strides for w1: [E, 2N, K]
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
// K and N are packed for int8
const int64_t packed_K = get_row_size<int8_t>(K);
const int64_t packed_N = get_row_size<int8_t>(N);
const int64_t stride_e = 2 * N * packed_K;
const int64_t stride_n = packed_K;
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
alignas(64) float As[BLOCK_M];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, Aq_tmp + index * K, K);
As[m] = As_tmp[index];
}
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
const int64_t offset = offsets[mb];
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + offset * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(
Aq_tmp + m * N,
As_tmp[m],
ic1 + m * N,
N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [E, K, N] as [E, OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
const int64_t stride_e2 = OC * packed_N;
const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
const float* __restrict__ As = As_tmp + offsets[mb];
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m];
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
}
});
// stage 3: out = intermediate_cache2.sum(dim=1)
// from [M, topk, K] to [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
}
});
}
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
template void fused_experts_int8_kernel_impl<TYPE> ( \
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
const float* __restrict__ w1s, const float* __restrict__ w2s, \
const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \
const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \
int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad)
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K) {
// handle 2 tiles per block
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 0: quantize input to uint8, [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(
Aq_tmp + m * K,
As_tmp[m],
input + m * K,
K);
}
});
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
// K and N are packed for int8
const int64_t packed_K = get_row_size<int8_t>(K);
const int64_t packed_N = get_row_size<int8_t>(N);
const int64_t stride_n = packed_K;
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// A shape [m_size, K]
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
const float* As = As_tmp + mb * BLOCK_M;
// B shape [K, n_size] in vnni format
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(
Aq_tmp + m * N,
As_tmp[m],
ic1 + m * N,
N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [K, N] as [OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// A shape [m_size, IC]
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
}
});
}
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
template void shared_expert_int8_kernel_impl<TYPE> ( \
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
const float* __restrict__ w1s, const float* __restrict__ w2s, \
const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \
int64_t M, int64_t N, int64_t K)
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);

308
csrc/cpu/sgl-kernels/vec.h Normal file
View File

@@ -0,0 +1,308 @@
// Adapted from
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
#pragma once
// clang-format off
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endif
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
namespace {
using namespace at::vec;
template <typename scalar_t,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
return at::vec::convert_from_float<scalar_t>(a, b);
}
#if defined(CPU_CAPABILITY_AVX512)
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
// use native instruction for bfloat16->float32 conversion
template <>
inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
}
#define CVT_BF16_TO_FP32(a) \
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
#define CVT_FP16_TO_FP32(a) \
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
// this doesn't handle NaN.
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
const __m512i nonsign = _mm512_or_si512(exp, mant);
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
const __m512i combined = _mm512_or_si512(nonsign, sign);
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
}
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
// The following conversion is without denorm behavior, that is to say,
// Max subnorm : S.0000.111 = 0.875 2**(6)
// Min subnorm : S.0000.001 = 2**(9)
// 0.0019 ~ 0.0137 cannot be converted correctly.
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
auto mask = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_setzero_si512()); // mask = x & 0x7f
auto mask_nan = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_set1_epi16(127)); // mask_nan = x & 0x7f
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
auto exponent = _mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
return (__m512bh)(_mm512_or_si512(
nonsign,
_mm512_slli_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(128)),
8))); // add sign (x & 128) << 8
}
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
__m512i lg2mant = _mm512_mask_mov_epi16(
_mm512_mask_mov_epi16(
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
_mm512_set1_epi16(2));
return (__m512bh)(_mm512_or_si512(
_mm512_maskz_mov_epi16(
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
_mm512_mask_blend_epi16(
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
_mm512_or_si512(
_mm512_and_si512(
_mm512_sllv_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
_mm512_set1_epi16(0x007f)),
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
_mm512_or_si512(
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
_mm512_slli_epi16(
_mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
7)))),
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
}
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
#ifdef SGLANG_CPU_FP8_CVT_FTZ
return cvt_e4m3_bf16_intrinsic_no_nan(a);
#else
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
#endif
}
#endif
// vector to scalar reduction
#if defined(CPU_CAPABILITY_AVX512) && 0
inline float vec_reduce_sum(const Vectorized<float>& a) {
return _mm512_reduce_add_ps(__m512(a));
}
inline float vec_reduce_max(const Vectorized<float>& a) {
return _mm512_reduce_max_ps(__m512(a));
}
#else
inline float vec_reduce_sum(const Vectorized<float>& a) {
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
}
inline float vec_reduce_max(const Vectorized<float>& a) {
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
}
#endif
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
template <typename scalar_t>
inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As,
const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
float amax = 0.f; // absolute max
for (int64_t k = 0; k < K; ++k) {
const float val = static_cast<float>(A[k]);
amax = std::max(amax, std::abs(val));
}
amax = std::max(amax, eps);
const float scale = amax / 127;
const float inv_scale = 127 / amax;
for (int64_t k = 0; k < K; ++k) {
const float val = static_cast<float>(A[k]) * inv_scale;
Aq[k] = (uint8_t)(std::round(val)) + 128;
}
As = scale;
}
#if defined(CPU_CAPABILITY_AVX512)
template <>
inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As,
const at::BFloat16* __restrict__ A, int64_t K, float eps) {
const __m512 signBit = _mm512_set1_ps(-0.0f);
const __m512i off = _mm512_set1_epi32(128);
// K is 32x, no remainder
float amax = 0.f;
__m512 vamax0 = _mm512_set1_ps(0.f);
__m512 vamax1 = _mm512_set1_ps(0.f);
for (int64_t k = 0; k < K; k += 32) {
__m512i va = _mm512_loadu_si512((void*)(A + k));
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
}
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
amax = std::max(amax, eps);
const float scale = amax / 127;
const float inv_scale = 127 / amax;
const __m512 vd = _mm512_set1_ps(inv_scale);
for (int64_t k = 0; k < K; k += 32) {
__m512i va = _mm512_loadu_si512((void*)(A + k));
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
va0 = _mm512_mul_ps(va0, vd);
va1 = _mm512_mul_ps(va1, vd);
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
}
As = scale;
}
#endif
// transpose utils
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
#if defined(CPU_CAPABILITY_AVX512)
inline void transpose_16x16_32bit(__m512i * v) {
__m512i v1[16];
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
}
// remove warning : ignoring attributes on template argument __m512i [-Wignored-attributes]
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
// transpose from [2, 32] to [32, 2]
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
// r0: {a0, a1, ..., a31}
// r1: {b0, b1, ..., b31}
//
// d0: {a0, b0, ..., a15, b15}
// d1: {a16, b16, ..., a31, b31}
//
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
return std::make_tuple(d0, d1);
}
#pragma GCC diagnostic pop
#endif
// TODO: debug print, remove me later
template<typename scalar_t>
void print_array(scalar_t* ptr, int size) {
for (int d = 0; d < size; ++d) {
if (d % 16 == 0) { std::cout << std::endl; }
std::cout << ptr[d] << " ";
}
std::cout << std::endl;
}
} // anonymous namespace

818
csrc/cpu/shm.cpp Normal file
View File

@@ -0,0 +1,818 @@
#include "cpu/cpu_types.hpp"
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
namespace {
#define MAX_SHM_RANK_NUM 8
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
#define MIN_THREAD_PROCESS_SIZE (256)
#define MAX_P2P_SEND_TENSOR_NUM 8
template <typename scalar_t>
struct KernelVecType {
using scalar_vec_t = void;
};
template <>
struct KernelVecType<float> {
using scalar_vec_t = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::BFloat16> {
using scalar_vec_t = vec_op::BF16Vec16;
};
template <>
struct KernelVecType<c10::Half> {
using scalar_vec_t = vec_op::FP16Vec16;
};
struct ThreadSHMContext {
volatile char _curr_thread_stamp[2];
volatile char _ready_thread_stamp[2];
int local_stamp_buffer_idx;
int remote_stamp_buffer_idx;
int thread_id;
int thread_num;
int rank;
int group_size;
size_t _spinning_count;
int swizzled_ranks[MAX_SHM_RANK_NUM];
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
size_t _thread_buffer_mask[2];
char _padding2[40];
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
const int group_size, void* thread_shm_ptr)
: local_stamp_buffer_idx(0),
remote_stamp_buffer_idx(0),
thread_id(thread_id),
thread_num(thread_num),
rank(rank),
group_size(group_size),
_spinning_count(0) {
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
TORCH_CHECK((size_t)this % 64 == 0);
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
_curr_thread_stamp[0] = 1;
_curr_thread_stamp[1] = 1;
_ready_thread_stamp[0] = 0;
_ready_thread_stamp[1] = 0;
_thread_buffer_mask[0] = 0;
_thread_buffer_mask[1] = 0;
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
shm_contexts[i] = nullptr;
thread_shm_ptrs[i] = nullptr;
swizzled_ranks[i] = (i + rank) % group_size;
}
set_context(rank, this, thread_shm_ptr);
}
void set_stamp_buffer_idx(int local, int remote) {
local_stamp_buffer_idx = local;
remote_stamp_buffer_idx = remote;
}
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
TORCH_CHECK(ptr);
TORCH_CHECK(thread_shm_ptr);
TORCH_CHECK_EQ(ptr->thread_num, thread_num);
TORCH_CHECK_EQ(ptr->thread_id, thread_id);
shm_contexts[rank] = ptr;
thread_shm_ptrs[rank] = thread_shm_ptr;
}
template <typename T>
T* get_thread_shm_ptr(int rank) {
return reinterpret_cast<T*>(
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
(PER_THREAD_SHM_BUFFER_OFFSET &
_thread_buffer_mask[local_stamp_buffer_idx]));
}
void next_buffer() {
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
}
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
void next_stamp() {
_mm_mfence();
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
}
void commit_ready_stamp() {
_mm_mfence();
_ready_thread_stamp[local_stamp_buffer_idx] =
_curr_thread_stamp[local_stamp_buffer_idx];
}
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
template <typename Cond>
void wait_for_all(Cond&& cond) {
for (int idx = 1; idx < group_size; ++idx) {
int rank = get_swizzled_rank(idx);
wait_for_one(rank, std::forward<Cond>(cond));
}
}
template <typename Cond>
void wait_for_one(int rank, Cond&& cond) {
ThreadSHMContext* rank_ctx = shm_contexts[rank];
for (;;) {
char local_curr_stamp = get_curr_stamp(local_stamp_buffer_idx);
char local_ready_stamp = get_ready_stamp(local_stamp_buffer_idx);
char rank_curr_stamp = rank_ctx->get_curr_stamp(remote_stamp_buffer_idx);
char rank_ready_stamp =
rank_ctx->get_ready_stamp(remote_stamp_buffer_idx);
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
rank_ready_stamp)) {
break;
}
++_spinning_count;
_mm_pause();
}
}
static bool check_no_buffer_conflict(char local_curr_stamp,
char local_ready_stamp,
char rank_curr_stamp,
char rank_ready_stamp) {
char temp = rank_curr_stamp + 2;
return local_curr_stamp != temp;
}
static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp,
char rank_curr_stamp, char rank_ready_stamp) {
char temp = local_curr_stamp + 1;
return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp);
}
std::string to_string() const {
std::stringstream ss;
ss << "SHMContext:";
ss << "\nrank: " << rank;
ss << "\ngroup_size: " << group_size;
ss << "\nthread_num: " << thread_num;
ss << "\nthread_id: " << thread_id;
ss << "\nshm_ctx_stat_loop_seq: [";
for (int i = 0; i < group_size; ++i) {
ss << swizzled_ranks[i] << ", ";
}
ss << "]";
ss << "\nshm_contexts: [";
for (int i = 0; i < group_size; ++i) {
if (shm_contexts[i]) {
ss << shm_contexts[i]->rank << ", ";
}
}
ss << "]";
return ss.str();
}
};
class SHMManager {
public:
explicit SHMManager(const std::string& name, const int rank,
const int group_size)
: _rank(rank),
_group_size(group_size),
_thread_num(omp_get_max_threads()),
_shm_names({""}),
_shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) {
_shm_names[rank] = get_shm_name(name, rank);
_shared_mem_ptrs[rank] = init_shm(rank);
_shm_ctx = reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank]);
for (int i = 0; i < _thread_num; ++i) {
ThreadSHMContext* ctx = new (_shm_ctx + i)
ThreadSHMContext(i, _thread_num, _rank, _group_size,
compute_thread_shm_ptr(_shm_ctx, i));
}
}
void join(const std::string& name) {
for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) {
if (rank_idx != _rank) {
TORCH_CHECK(_shm_names[rank_idx].empty());
TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr);
_shm_names[rank_idx] = get_shm_name(name, rank_idx);
_shared_mem_ptrs[rank_idx] = init_shm(rank_idx);
ThreadSHMContext* target_ctx =
reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank_idx]);
for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) {
_shm_ctx[thread_idx].set_context(
rank_idx, target_ctx + thread_idx,
compute_thread_shm_ptr(target_ctx, thread_idx));
}
}
}
}
~SHMManager() { destroy_shm(); }
ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; }
static std::string get_shm_name(const std::string& name, int rank) {
return name + "_" + std::to_string(rank);
}
static int64_t create_singleton_instance(const std::string& name,
const int group_size,
const int rank) {
std::lock_guard<std::mutex> guard(SingletonInstancesLock);
SingletonInstances.emplace_back(
std::make_unique<SHMManager>(name, rank, group_size));
return static_cast<int64_t>(SingletonInstances.size() - 1);
}
static SHMManager* get_singleton_instance(int64_t handle) {
return SingletonInstances[handle].get();
}
protected:
static std::vector<std::unique_ptr<SHMManager>> SingletonInstances;
static std::mutex SingletonInstancesLock;
private:
static size_t round_to_alignment(size_t num) {
return ((num + 63) / 64) * 64;
}
int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) {
int8_t* thread_shm_ptr =
reinterpret_cast<int8_t*>(ctx) +
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
return thread_shm_ptr +
thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES);
}
size_t compute_shm_size() {
const size_t rounded_rank_buffer_size =
round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num;
const size_t rounded_thread_shm_ctx_size =
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
const size_t shm_size =
rounded_thread_shm_ctx_size + rounded_rank_buffer_size;
return shm_size;
}
void* init_shm(int target_rank) {
const std::string& shm_name = _shm_names[target_rank];
const int local_rank = _rank;
const size_t shm_size = compute_shm_size();
int fd = -1;
if (local_rank == target_rank) {
fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR,
S_IRUSR | S_IWUSR);
if (fd == -1)
TORCH_CHECK(false, "create shm in SHMManager failed. errno: " +
std::to_string(errno));
if (ftruncate(fd, shm_size) == -1)
TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " +
std::to_string(errno));
} else {
fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
if (fd == -1)
TORCH_CHECK(false, "open shm in SHMManager failed. errno: " +
std::to_string(errno));
}
void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE,
MAP_SHARED | MAP_POPULATE, fd, 0);
if (shm_ptr == MAP_FAILED) {
TORCH_CHECK(false,
"mmap in SHMManager failed. errno: " + std::to_string(errno));
}
if (close(fd) != 0) {
TORCH_CHECK(
false, "close in SHMManager failed. errno: " + std::to_string(errno));
}
TORCH_CHECK((size_t)shm_ptr % 64 == 0);
return shm_ptr;
}
void destroy_shm() {
std::stringstream ss;
ss << "local rank " << _rank << ": [";
for (int thread_id = 0; thread_id < _thread_num; ++thread_id) {
ss << _shm_ctx[thread_id]._spinning_count << ", ";
}
ss << "]\n";
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
if (_shared_mem_ptrs[i] != nullptr) {
munmap(_shared_mem_ptrs[i], compute_shm_size());
}
if (!_shm_names[i].empty()) {
shm_unlink(_shm_names[i].c_str());
}
}
}
int _rank;
int _group_size;
int _thread_num;
std::array<std::string, MAX_SHM_RANK_NUM> _shm_names;
std::array<void*, MAX_SHM_RANK_NUM> _shared_mem_ptrs;
ThreadSHMContext* _shm_ctx;
};
namespace shm_cc_ops {
template <typename scalar_t, typename F>
void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
int thread_num = ctx->thread_num;
int64_t total_bytes = elem_num * sizeof(scalar_t);
int64_t total_units_num =
(total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE;
int64_t per_thread_units_num =
(total_units_num + thread_num - 1) / thread_num;
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
int64_t max_per_thread_iteration_elem_num =
(PER_THREAD_SHM_BUFFER_BYTES >> 1) /
sizeof(scalar_t); // Note: double buffer
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
#pragma omp parallel for schedule(static, 1)
for (int i = 0; i < thread_num; ++i) {
int64_t offset = i * per_thread_elem_num;
int64_t end = std::min(elem_num, offset + per_thread_elem_num);
int64_t curr_elem_num =
std::min(max_per_thread_iteration_elem_num, end - offset);
ThreadSHMContext* thread_ctx = ctx + i;
bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num);
while (curr_elem_num > 0) {
inner_func(thread_ctx, offset, curr_elem_num, fast_mode);
thread_ctx->next_stamp();
thread_ctx->next_buffer();
offset += max_per_thread_iteration_elem_num;
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
}
}
}
void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local,
int remote) {
int thread_num = ctx->thread_num;
for (int i = 0; i < thread_num; ++i) {
ThreadSHMContext* thread_ctx = ctx + i;
thread_ctx->set_stamp_buffer_idx(local, remote);
}
}
}; // namespace shm_cc_ops
namespace shm_cc_ops {
void memcpy_from_shm(void* dst, void* src, const int64_t bytes) {
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
int64_t i = 0;
#pragma GCC unroll 4
for (; i < aligned_bytes; i += 64) {
vec_op::INT8Vec64 data(
true, (int8_t*)src + i); // stream loading shm to avoid caching
data.save((int8_t*)dst + i);
}
if (aligned_bytes < bytes) {
vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes);
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
}
}
void memcpy_to_shm(void* dst, void* src, const int64_t bytes) {
#pragma GCC unroll 4
for (int64_t i = 0; i < bytes; i += 64) {
vec_op::INT8Vec64 data((int8_t*)src + i);
data.nt_save((int8_t*)dst + i);
}
}
void memcpy(void* dst, void* src, const int64_t bytes) {
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
int64_t i = 0;
#pragma GCC unroll 4
for (; i < aligned_bytes; i += 64) {
vec_op::INT8Vec64 data((int8_t*)src + i);
data.save((int8_t*)dst + i);
}
if (aligned_bytes < bytes) {
vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes);
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
}
}
template <typename scalar_t, int RANKS>
void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
size_t elem_num) {
CPU_KERNEL_GUARD_IN(all_reduce_sum_impl)
using vec_t = typename KernelVecType<scalar_t>::scalar_vec_t;
constexpr int64_t vec_elem_num = vec_t::get_elem_num();
const int worldsize = ctx->group_size;
shm_cc_ops::shm_cc_loop<scalar_t>(
ctx, elem_num,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num, bool fast_mode) {
int rank = thread_ctx->rank;
scalar_t* thread_shm_ptr =
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
scalar_t* thread_data_ptr = data + data_offset;
int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t);
scalar_t* remote_data_ptrs[RANKS - 1];
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr<scalar_t>(
thread_ctx->get_swizzled_rank(idx + 1));
});
if (!fast_mode) {
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
}
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
thread_data_elem_num);
thread_ctx->commit_ready_stamp();
int64_t aligned_data_elem_num =
(data_elem_num / vec_elem_num) * vec_elem_num;
int64_t i = 0;
thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready);
#pragma GCC unroll 4
for (; i < aligned_data_elem_num; i += vec_elem_num) {
vec_t local_data(thread_data_ptr + i); // load from cache
vec_op::FP32Vec16 local_data_fp32(local_data);
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
vec_t remote_data(
true, remote_data_ptrs[idx] + i); // stream load from shm
vec_op::FP32Vec16 remote_data_fp32(remote_data);
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
});
vec_t reduced_data(local_data_fp32);
reduced_data.save(thread_data_ptr + i);
}
if (i < data_elem_num) {
vec_t local_data(thread_data_ptr + i); // load from cache
vec_op::FP32Vec16 local_data_fp32(local_data);
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
vec_t remote_data(
true, remote_data_ptrs[idx] + i); // stream load from shm
vec_op::FP32Vec16 remote_data_fp32(remote_data);
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
});
vec_t reduced_data(local_data_fp32);
reduced_data.save(thread_data_ptr + i,
data_elem_num - aligned_data_elem_num);
}
});
return;
}
}; // namespace shm_cc_ops
std::vector<std::unique_ptr<SHMManager>> SHMManager::SingletonInstances = {};
std::mutex SHMManager::SingletonInstancesLock = {};
template <typename scalar_t>
void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) {
switch (ctx->group_size) {
case 2:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 2>(ctx, data, elem_num);
break;
case 3:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 3>(ctx, data, elem_num);
break;
case 4:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 4>(ctx, data, elem_num);
break;
case 8:
shm_cc_ops::all_reduce_sum_impl<scalar_t, 8>(ctx, data, elem_num);
break;
default:
TORCH_CHECK(false,
"Invalid world size: " + std::to_string(ctx->group_size));
}
}
template <typename scalar_t>
void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
scalar_t** outputs, const int dst) {
CPU_KERNEL_GUARD_IN(shm_gather_impl)
const int worldsize = ctx->group_size;
TORCH_CHECK_LT(dst, worldsize);
shm_cc_ops::shm_cc_loop<scalar_t>(
ctx, elem_num,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num, bool fast_mode) {
int rank = thread_ctx->rank;
scalar_t* thread_shm_ptr =
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
if (!fast_mode) {
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
}
shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset,
data_elem_num * sizeof(scalar_t));
thread_ctx->commit_ready_stamp();
if (rank == dst) {
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
data_elem_num * sizeof(scalar_t));
for (int i = 1; i < worldsize; ++i) {
int src_rank = thread_ctx->get_swizzled_rank(i);
scalar_t* src_ptr =
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
thread_ctx->wait_for_one(src_rank,
ThreadSHMContext::check_stamp_ready);
shm_cc_ops::memcpy(dst_ptr, src_ptr,
data_elem_num * sizeof(scalar_t));
}
}
});
return;
}
struct MemPiece {
void* ptr;
int64_t size;
template <typename T>
T* data_ptr() {
return reinterpret_cast<T*>(ptr);
}
};
struct TensorListMeta {
int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM];
torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM];
int64_t tensor_num;
int64_t total_bytes;
TensorListMeta() : tensor_num(0), total_bytes(0) {
static_assert(sizeof(TensorListMeta) % 64 == 0);
static_assert(sizeof(TensorListMeta) <
MIN_THREAD_PROCESS_SIZE); // To ensure the metadata always
// hold by the thread 0
for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) {
tensor_bytes[i] = 0;
tensor_ptrs[i] = nullptr;
tensor_types[i] = torch::ScalarType::Undefined;
}
}
// For send and recv
void bind_tensor_list(std::vector<torch::Tensor>& tensor_list) {
TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined,
"Re-bind TensorListMeta is not allowed.")
TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM);
tensor_num = tensor_list.size();
int64_t bytes_sum = 0;
for (int i = 0; i < tensor_list.size(); ++i) {
torch::Tensor& t = tensor_list[i];
TORCH_CHECK(t.is_contiguous());
tensor_bytes[i] = t.nbytes();
tensor_types[i] = t.scalar_type();
tensor_ptrs[i] = t.data_ptr();
bytes_sum += t.nbytes();
}
total_bytes = bytes_sum;
}
// For recv
std::vector<torch::Tensor> generate_tensor_list() {
std::vector<torch::Tensor> tensor_list;
tensor_list.reserve(tensor_num);
for (int i = 0; i < tensor_num; ++i) {
int64_t bytes = tensor_bytes[i];
auto type = tensor_types[i];
int64_t elem_bytes = torch::elementSize(type);
TORCH_CHECK_EQ(bytes % elem_bytes, 0);
int64_t elem_num = bytes / elem_bytes;
auto options = torch::TensorOptions().dtype(type).device(torch::kCPU);
tensor_list.emplace_back(torch::empty({elem_num}, options));
}
return tensor_list;
}
MemPiece get_data(int64_t offset) {
for (int i = 0; i < tensor_num; ++i) {
if (offset < tensor_bytes[i]) {
return {reinterpret_cast<int8_t*>(tensor_ptrs[i]) + offset,
tensor_bytes[i] - offset};
}
offset -= tensor_bytes[i];
}
return {nullptr, 0};
}
private:
void* tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM];
int8_t _padding[40];
};
void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
const std::vector<torch::Tensor>& tensor_list) {
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
std::vector<torch::Tensor> tensor_list_with_metadata;
tensor_list_with_metadata.reserve(1 + tensor_list.size());
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
tensor_list_with_metadata.emplace_back(
torch::empty({sizeof(TensorListMeta)}, options));
tensor_list_with_metadata.insert(tensor_list_with_metadata.end(),
tensor_list.begin(), tensor_list.end());
torch::Tensor& metadata_tensor = tensor_list_with_metadata[0];
TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta));
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
metadata->bind_tensor_list(tensor_list_with_metadata);
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 0, 1);
shm_cc_ops::shm_cc_loop<int8_t>(
ctx, metadata->total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num, bool fast_mode) {
int rank = thread_ctx->rank;
int64_t curr_shm_offset = 0;
thread_ctx->wait_for_one(dst,
ThreadSHMContext::check_no_buffer_conflict);
while (curr_shm_offset < data_elem_num) {
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
shm_cc_ops::memcpy(
thread_ctx->get_thread_shm_ptr<int8_t>(rank) + curr_shm_offset,
frag.ptr, frag.size);
curr_shm_offset += frag.size;
}
thread_ctx->commit_ready_stamp();
});
}
std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
int64_t src) {
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl)
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
torch::Tensor metadata_tensor =
torch::empty({sizeof(TensorListMeta)}, options);
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 1, 0);
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
ctx->get_thread_shm_ptr<void>(src),
sizeof(TensorListMeta));
TensorListMeta* src_metadata =
reinterpret_cast<TensorListMeta*>(metadata_tensor.data_ptr());
std::vector<torch::Tensor> tensor_list_with_metadata =
src_metadata->generate_tensor_list();
TensorListMeta metadata;
metadata.bind_tensor_list(tensor_list_with_metadata);
TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num);
TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes);
shm_cc_ops::shm_cc_loop<int8_t>(
ctx, metadata.total_bytes,
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
int64_t data_elem_num, bool fast_mode) {
thread_ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
int64_t curr_shm_offset = 0;
while (curr_shm_offset < data_elem_num) {
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
shm_cc_ops::memcpy(
frag.ptr,
thread_ctx->get_thread_shm_ptr<int8_t>(src) + curr_shm_offset,
frag.size);
curr_shm_offset += frag.size;
}
});
std::vector<torch::Tensor> tensor_list;
tensor_list.reserve(metadata.tensor_num - 1);
tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1,
tensor_list_with_metadata.end());
return tensor_list;
}
} // namespace
void shm_gather(int64_t handle, torch::Tensor& data,
const std::optional<std::vector<torch::Tensor>>& outputs,
int64_t dst) {
TORCH_CHECK(data.is_contiguous())
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] {
CPU_KERNEL_GUARD_IN(shm_gather_impl)
if (outputs.has_value()) {
TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM);
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
for (int i = 0; i < outputs->size(); ++i) {
output_ptrs[i] = outputs->at(i).data_ptr<scalar_t>();
}
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
dst);
} else {
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
data.data_ptr<scalar_t>(), data.numel(), (scalar_t**)(0),
dst);
}
CPU_KERNEL_GUARD_OUT(shm_gather_impl)
});
}
void shm_all_gather(int64_t handle, const torch::Tensor& data,
torch::Tensor& output) {
TORCH_CHECK(data.is_contiguous())
TORCH_CHECK(output.is_contiguous())
const int64_t input_elem_num = data.numel();
const int64_t output_elem_num = output.numel();
TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0);
const int world_size = output_elem_num / input_elem_num;
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] {
CPU_KERNEL_GUARD_IN(shm_all_gather_impl)
auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx();
TORCH_CHECK_EQ(ctx->group_size, world_size);
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
for (int i = 0; i < world_size; ++i) {
output_ptrs[i] = output.data_ptr<scalar_t>() + i * input_elem_num;
}
shm_gather_impl(ctx, data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
ctx->rank);
CPU_KERNEL_GUARD_OUT(shm_all_gather_impl)
});
}
void shm_allreduce(int64_t handle, torch::Tensor& data) {
TORCH_CHECK(data.is_contiguous())
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] {
CPU_KERNEL_GUARD_IN(shm_allreduce_sum)
shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
data.data_ptr<scalar_t>(), data.numel());
CPU_KERNEL_GUARD_OUT(shm_allreduce_sum)
});
}
void shm_send_tensor_list(int64_t handle,
const std::vector<torch::Tensor>& tensor_list,
int64_t dst) {
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
shm_send_tensor_list_impl(
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst,
tensor_list);
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
}
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list)
auto tensor_list = shm_recv_tensor_list_impl(
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src);
CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list)
return tensor_list;
}
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank) {
return SHMManager::create_singleton_instance(name, group_size, rank);
}
std::string join_shm_manager(int64_t handle, const std::string& name) {
auto shm_manager = SHMManager::get_singleton_instance(handle);
TORCH_CHECK(shm_manager);
shm_manager->join(name);
return shm_manager->get_shm_ctx()->to_string();
}

314
csrc/cpu/torch_bindings.cpp Normal file
View File

@@ -0,0 +1,314 @@
#include "cache.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>
std::string init_cpu_threads_env(const std::string& cpu_ids);
void release_dnnl_matmul_handler(int64_t handler);
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
const torch::Tensor& b_scales,
at::ScalarType output_type,
bool dynamic_act_quant, bool use_azp,
int64_t primitive_cache_size);
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& a_scales,
const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& azp_adj,
const std::optional<torch::Tensor>& bias,
int64_t handler);
int64_t create_onednn_mm_handler(const torch::Tensor& b,
int64_t primitive_cache_size);
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias, int64_t handler);
bool is_onednn_acl_supported();
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank);
std::string join_shm_manager(int64_t handle, const std::string& name);
void shm_allreduce(int64_t handle, torch::Tensor& data);
void shm_gather(int64_t handle, torch::Tensor& data,
const std::optional<std::vector<torch::Tensor>>& outputs,
int64_t dst);
void shm_all_gather(int64_t handle, const torch::Tensor& data,
torch::Tensor& output);
void shm_send_tensor_list(int64_t handle,
const std::vector<torch::Tensor>& tensor_list,
int64_t dst);
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias,
bool is_vnni);
at::Tensor convert_weight_packed(at::Tensor& weight);
at::Tensor fused_experts_cpu(
at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace,
bool use_int8_w8a8, bool use_fp8_w8a16,
const std::optional<at::Tensor>& w1_scale,
const std::optional<at::Tensor>& w2_scale,
const std::optional<std::vector<int64_t>> block_size,
const std::optional<at::Tensor>& a1_scale,
const std::optional<at::Tensor>& a2_scale, bool is_vnni);
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
at::Tensor& scales2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split);
void cpu_attn_reshape_and_cache(const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
const torch::Tensor& slot_mapping,
const std::string& isa);
void cpu_attention_with_kv_cache(
const torch::Tensor& query, const torch::Tensor& key_cache,
const torch::Tensor& value_cache, torch::Tensor& output,
const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes,
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, const double softcap,
const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux);
// Note: just for avoiding importing errors
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
torch::Tensor& output, const torch::Tensor& scales,
const std::optional<torch::Tensor>& zeros,
const std::optional<torch::Tensor>& g_idx,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
ops.def(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
"Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
"int group_size, bool apply_router_weight_on_input, int activation_kind"
") -> Tensor");
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCPU, &gelu_new);
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCPU, &gelu_quick);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCPU, &rms_norm);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
// Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
defined(__powerpc64__)
// Helper function to release oneDNN handlers
ops.def("release_dnnl_matmul_handler(int handler) -> ()",
&release_dnnl_matmul_handler);
// Create oneDNN GEMM handler
ops.def(
"create_onednn_mm_handler(Tensor b, int "
"primitive_cache_size) -> int",
&create_onednn_mm_handler);
// oneDNN GEMM
ops.def(
"onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
"int handler) -> ()");
ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
// Check if oneDNN was built with ACL backend
ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);
// Create oneDNN W8A8 handler
ops.def(
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
"output_type, bool dynamic_act_quant, bool use_azp, int "
"primitive_cache_size) -> int",
&create_onednn_scaled_mm_handler);
// oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
ops.def(
"onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
"Tensor? azp_adj, Tensor? bias, int handler) -> ()");
ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
#endif
// SHM CCL
#ifdef __AVX512F__
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
&init_shm_manager);
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
ops.def(
"shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
"()");
ops.impl("shm_gather", torch::kCPU, &shm_gather);
ops.def(
"shm_all_gather(int handle, Tensor data, Tensor! output) -> "
"()");
ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
ops.def(
"shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
"()");
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
&shm_recv_tensor_list);
#endif
// sgl-kernels
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
ops.def(
"weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
"bias, bool is_vnni) -> Tensor");
ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
ops.def(
"fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
"topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
"use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
"block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
"Tensor");
ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
ops.def(
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant);
#endif
// CPU attention kernels
ops.def(
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
"int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
"query_start_loc, bool casual, int window_size, str isa_hint, bool "
"enable_kv_split) -> Tensor",
&get_scheduler_metadata);
ops.def(
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
"isa) -> ()",
&cpu_attn_reshape_and_cache);
ops.def(
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
&cpu_attention_with_kv_cache);
// placeholders
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
// WNA16
#if defined(__AVX512F__)
ops.def(
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
// CPU utils
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
cpu_ops.def(
"mla_decode_kvcache("
" Tensor! out, Tensor query, Tensor kv_cache,"
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

141
csrc/cpu/utils.cpp Normal file
View File

@@ -0,0 +1,141 @@
#ifndef VLLM_NUMA_DISABLED
#include <numa.h>
#include <unistd.h>
#include <string>
#include <sched.h>
#endif
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
#include <unistd.h>
#include <sys/syscall.h>
#define gettid() syscall(SYS_gettid)
#endif
#include "cpu_types.hpp"
#ifdef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
return std::string(
"Warning: NUMA is not enabled in this build. `init_cpu_threads_env` has "
"no effect to setup thread affinity.");
}
#endif
#ifndef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size);
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
int i = 0;
while (group_mask) {
if (group_mask & 1) {
omp_cpu_ids.emplace_back(offset + i);
}
++i;
group_mask >>= 1;
}
}
// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
std::set<int> node_ids;
for (const auto& cpu_id : omp_cpu_ids) {
int node_id = numa_node_of_cpu(cpu_id);
if (node_id != -1) {
node_ids.insert(node_id);
}
if (node_id != mem_node_id) {
TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ",
omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
". All CPUs should be on the same NUMA node for optimal "
"performance. Memory will be bound to NUMA node ",
mem_node_id, ".");
}
}
// Concatenate all node_ids into a single comma-separated string
if (!node_ids.empty()) {
std::string node_ids_str;
for (const int node_id : node_ids) {
if (!node_ids_str.empty()) {
node_ids_str += ",";
}
node_ids_str += std::to_string(node_id);
}
bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
bitmask* src_mask = numa_get_membind();
int pid = getpid();
if (mask && src_mask) {
// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_WARN("numa_migrate_pages failed. errno: " +
std::to_string(errno));
}
// restrict memory allocation node.
numa_set_membind(mask);
numa_set_strict(1);
numa_free_nodemask(mask);
numa_free_nodemask(src_mask);
} else {
TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " +
std::to_string(errno));
}
}
}
// OMP threads binding
omp_set_num_threads((int)omp_cpu_ids.size());
torch::set_num_threads((int)omp_cpu_ids.size());
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
std::vector<std::pair<int, int>> thread_core_mapping;
thread_core_mapping.reserve(omp_cpu_ids.size());
omp_lock_t writelock;
omp_init_lock(&writelock);
#pragma omp parallel for schedule(static, 1)
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(omp_cpu_ids[i], &mask);
int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
if (ret == -1) {
TORCH_CHECK(false,
"sched_setaffinity failed. errno: " + std::to_string(errno));
}
omp_set_lock(&writelock);
thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
omp_unset_lock(&writelock);
}
omp_destroy_lock(&writelock);
numa_free_nodemask(omp_cpu_mask);
std::stringstream ss;
ss << "OMP threads binding of Process " << getpid() << ":\n";
std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
[](auto&& a, auto&& b) { return a.second < b.second; });
for (auto&& item : thread_core_mapping) {
ss << "\t"
<< "OMP tid: " << item.first << ", core " << item.second << "\n";
}
return ss.str();
}
#endif

73
csrc/cpu/utils.hpp Normal file
View File

@@ -0,0 +1,73 @@
#ifndef UTILS_HPP
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
namespace cpu_utils {
enum class ISA { AMX, VEC };
template <typename T>
struct VecTypeTrait {
using vec_t = void;
};
template <>
struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
struct Counter {
std::atomic<int64_t> counter;
char _padding[56];
Counter() : counter(0) {}
void reset_counter() { counter.store(0); }
int64_t acquire_counter() { return counter++; }
};
inline int64_t get_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
assert(l2_cache_size != -1);
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
} // namespace cpu_utils
#endif