CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
@@ -1,4 +1,5 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"

namespace {
@@ -11,19 +12,144 @@ namespace {
// 4. provide amx kernel for index_gemm_kernel_nn when M = 16
//

inline void fill_stub(float* __restrict__ out, float val, int64_t size) {
  using Vec = at::vec::Vectorized<float>;
  const Vec data_vec(val);
  at::vec::map<float>([data_vec](Vec out) { return out = data_vec; }, out, out, size);
#if defined(CPU_CAPABILITY_AVX512)
// key: from [N, 32] to [32/2, N, 2]
// val: from [N, 32] to [N/2, 32, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
    scalar_t* __restrict__ dst0,
    scalar_t* __restrict__ dst1,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int ld_src,
    int ld_dst0,
    int ld_dst1,
    bool convert_v) {
  __m512i vinputs[16];
  int n = 0;
  for (; n < N; ++n) {
    vinputs[n] = _mm512_loadu_si512(src + ind[n] * ld_src);
  }
  // padding with zero to avoid uninitialized vectors
  for (; n < 16; ++n) {
    vinputs[n] = _mm512_set1_epi32(0);
  }

  // pack value, skip 64 elems for deepseek
  // handle 2 vectors at a time from [2, 32] to [32, 2]
  if (convert_v) {
    for (int n = 0; n < 16; n += 2) {
      __m512i d0, d1;
      std::tie(d0, d1) = transpose_2x32_16bit(vinputs[n], vinputs[n + 1]);
      _mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2, d0);
      _mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2 + 32, d1);
    }
  }

  // pack key
  transpose_16x16_32bit(vinputs);

  const __mmask16 vmask = (1 << N) - 1;
  for (int k = 0; k < 16; ++k) {
    _mm512_mask_storeu_epi32(dst0 + k * ld_dst0 * 2, vmask, vinputs[k]);
  }
}
#endif

// [NOTE]: MLA vnni format conversion
//
// here we apply same strategy as `FlashMLA`:
// each kv_cache is loaded once and packed twice (L2 cache hit)
//
// * for key: from [N, K/2, 2] to [K/2, N, 2]
// * for value: from [N/2, 2, Kv] to [N/2, Kv, 2]
//
template <typename scalar_t, typename index_t>
void pack_vnni(
    scalar_t* __restrict__ dst0,
    scalar_t* __restrict__ dst1,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int K,
    int Kv,
    int ld_src,
    int ld_dst0,
    int ld_dst1) {
#if defined(CPU_CAPABILITY_AVX512)
  const int NB = div_up(N, 16);
  const int KB = K / 32;    // no remainder
  const int KBv = Kv / 32;  // no remainder

  for (int nb = 0; nb < NB; ++nb) {
    for (int kb = 0; kb < KB; ++kb) {
      // handle 16x512bits each block
      int nb_size = std::min(N - nb * 16, 16);
      pack_vnni_Nx32<scalar_t, index_t>(
          /* dst0 */ dst0 + ((kb * 32) >> 1) * ld_dst0 * 2 + nb * 16 * 2,
          /* dst1 */ dst1 + ((nb * 16) >> 1) * ld_dst1 * 2 + kb * 32 * 2,
          /* src */ src + kb * 32,
          /* ind */ ind + nb * 16,
          /* N */ nb_size,
          /* ld_src */ ld_src,
          /* ld_dst0 */ ld_dst0,
          /* ld_dst1 */ ld_dst1,
          /* cvt_v */ kb < KBv);
    }
  }
#else
  // key: from [N, K/2, 2] to [K/2, N, 2]
  for (int n = 0; n < N; ++n) {
    index_t index = ind[n];
    for (int k = 0; k < K / 2; ++k) {
      for (int d = 0; d < 2; ++d) {
        dst0[k * ld_dst0 * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
      }
    }
  }
  // value: from [N/2, 2, Kv] to [N/2, Kv, 2]
  for (int n = 0; n < (N >> 1) * 2; n += 2) {
    index_t index0 = ind[n + 0];
    index_t index1 = ind[n + 1];
    for (int k = 0; k < Kv; ++k) {
      dst1[(n >> 1) * ld_dst1 * 2 + k * 2 + 0] = src[index0 * ld_src + k];
      dst1[(n >> 1) * ld_dst1 * 2 + k * 2 + 1] = src[index1 * ld_src + k];
    }
  }
  if (N % 2 != 0) {
    index_t index = ind[N - 1];
    for (int k = 0; k < Kv; ++k) {
      dst1[(N >> 1) * ld_dst1 * 2 + k * 2 + 0] = src[index * ld_src + k];
      dst1[(N >> 1) * ld_dst1 * 2 + k * 2 + 1] = 0;
    }
  }
#endif
}
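
// Illustration (not part of this commit): for 16-bit types, AMX/VNNI dot
// products consume the B operand in pairs along K, which is why key element
// src[n][k] lands at dst0[k/2][n][k%2]; e.g. with ld_dst0 = BLOCK_N, the
// element (n = 3, k = 7) is stored at dst0[3 * BLOCK_N * 2 + 3 * 2 + 1],
// matching the scalar fallback indexing above.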

template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int64_t size) {
  using Vec = at::vec::Vectorized<scalar_t>;
  constexpr int kVecSize = Vec::size();
  const Vec data_vec = Vec(static_cast<scalar_t>(val));
  int64_t d = 0;
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    data_vec.store(out + d);
  }
  if (size - d > 0) {
    data_vec.store(out + d, size - d);
  }
}
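
// Note: at::vec::Vectorized<T>::store(ptr, count) writes only the first
// `count` lanes, so the tail above needs no separate scalar loop.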

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec s_fvec = fVec(s);
  int64_t d = 0;
  for (; d <= size - bVec::size(); d += bVec::size()) {
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
    fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
    bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
@@ -37,8 +163,10 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc,
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  constexpr int kVecSize = bVec::size();
  int64_t d = 0;
  for (; d <= size - bVec::size(); d += bVec::size()) {
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    bVec out_bvec = bVec::loadu(src + d);
    out_bvec.store(out + d);
  }
@@ -47,6 +175,26 @@ inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ s
  }
}

template <typename scalar_t, int BLOCK_N>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
  static_assert(BLOCK_N % 32 == 0);
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int COLS = BLOCK_N / 16;
  auto store = [&](auto i) {
    constexpr int col = i % COLS;
    // for COLS = 2, 4 use 512bit store
    if constexpr (col % 2 == 0) {
      fVec a_fvec0 = fVec::loadu(input + col * 16);
      fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
      bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
      out_bvec.store(out + col * 16);
    }
  };
  Unroll<COLS>{}(store);
}
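
// Illustration (not part of this commit): `Unroll<COLS>{}(store)` above needs a
// compile-time loop so `i % COLS` stays constexpr inside the lambda. A minimal
// sketch of such a helper (hypothetical name, assuming the repo's version is
// built on std::integral_constant; requires <utility>):
template <int N>
struct UnrollSketch {
  template <typename F, int... Is>
  static void run(F&& f, std::integer_sequence<int, Is...>) {
    (f(std::integral_constant<int, Is>{}), ...);  // expand f over 0..N-1
  }
  template <typename F>
  void operator()(F&& f) const {
    run(f, std::make_integer_sequence<int, N>{});
  }
};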

// GEMM handles query @ key (indexed) x scale
// A : [M, K]
// B : [N, K] indexed
@@ -619,16 +767,105 @@ void index_gemm_kernel_nn(
  }
}

template <typename scalar_t, typename index_t>
void decode_attention_kernel_impl(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    const scalar_t* __restrict__ query,
template <typename scalar_t>
void decode_set_kv_buffer(
    scalar_t* __restrict__ k_buffer,
    scalar_t* __restrict__ v_buffer,
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    const int64_t* __restrict__ loc,
    int64_t batches,
    int64_t num_heads_kv,
    int64_t head_size,
    int64_t head_size_v,
    int64_t k_strideN,
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    int64_t nk_strideN,
    int64_t nk_strideH,
    int64_t nv_strideN,
    int64_t nv_strideH,
    bool is_mla) {
  at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, head_kv_id{0};
    data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);

    for (int64_t i = begin; i < end; i++) {
      int64_t loc_val = loc[bs];
      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
      copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
      if (!is_mla) {
        scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
        const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
        copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
      }

      // move to the next index
      data_index_step(bs, batches, head_kv_id, num_heads_kv);
    }
  });
}
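
// Illustration (not part of this commit): data_index_init/data_index_step (repo
// helpers) walk the flattened [batches, num_heads_kv] space handed out by
// at::parallel_for. A minimal 2-D equivalent (hypothetical names):
inline void index2d_init(int64_t begin, int64_t& i, int64_t& j, int64_t nj) {
  i = begin / nj;  // outer coordinate
  j = begin % nj;  // inner coordinate
}
inline void index2d_step(int64_t& i, int64_t& j, int64_t nj) {
  if (++j == nj) {  // wrap the inner dimension, carry into the outer one
    j = 0;
    ++i;
  }
}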

template <typename scalar_t>
void decode_accumulate_kv_splits(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    int64_t batches,
    int64_t num_heads,
    int64_t head_size_v,
    int64_t num_kv_splits,
    int64_t l_stride1,
    int64_t l_stride2) {
  using Vec = at::vec::Vectorized<float>;

  // parallel on [batches, num_heads]
  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
    // NB: here we use logits[b][h][0] as acc, since
    // for the first kv split (kv_id == 0):
    //   m_delta = std::exp(-inf) = 0
    //   e_logic = std::exp(0) = 1
    //   acc = acc * m_delta + tv * e_logic = tv
    for (int64_t i = begin; i < end; ++i) {
      float* __restrict__ acc = attn_logits + i * l_stride1;

      float s_prime = 0.f;
      float m_prime = -std::numeric_limits<scalar_t>::infinity();

      // update acc with the partial results from each kv split
      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
        float* __restrict__ tv = acc + kv_id * l_stride2;
        const float tlogic = (acc + kv_id * l_stride2)[head_size_v];

        float m_i = std::max(tlogic, m_prime);
        float m_delta = std::exp(m_prime - m_i);
        float e_logic = std::exp(tlogic - m_i);
        if (kv_id != 0) {
          at::vec::map2<float>(
              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
              acc,
              acc,
              tv,
              head_size_v);
        }

        s_prime = s_prime * m_delta + e_logic;
        m_prime = m_i;
      }

      copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
    }
  });
}
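
// Illustration (not part of this commit): each kv split stores its partial
// weighted value sum plus its logsumexp in the extra float at offset
// head_size_v; the loop above merges splits with the standard online-softmax
// recurrence. A scalar, single-element sketch of the same merge (hypothetical
// name; requires <algorithm>, <cmath>, <limits>):
inline float merge_splits_sketch(const float* tv, const float* tlogic, int64_t nsplits) {
  float acc = 0.f, s = 0.f;
  float m = -std::numeric_limits<float>::infinity();
  for (int64_t k = 0; k < nsplits; ++k) {
    float m_i = std::max(tlogic[k], m);
    float m_delta = std::exp(m - m_i);    // rescales previously merged state
    float e = std::exp(tlogic[k] - m_i);  // weight of the incoming split
    acc = acc * m_delta + tv[k] * e;
    s = s * m_delta + e;
    m = m_i;
  }
  return acc / s;  // normalized attention output
}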

template <typename scalar_t, typename index_t, int64_t BLOCK_N>
void decode_attention_kernel_impl(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    const scalar_t* __restrict__ query,
    const scalar_t* __restrict__ k_buffer,
    const scalar_t* __restrict__ v_buffer,
    const index_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices,
    const int64_t* __restrict__ seq_lens,
@@ -641,38 +878,13 @@ void decode_attention_kernel_impl(
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    int64_t nk_strideN,
    int64_t nk_strideH,
    int64_t nv_strideN,
    int64_t nv_strideH,
    float scaling,
    float logit_cap,
    int64_t max_num_reqs,
    int64_t max_context_len,
    int64_t max_total_num_tokens) {
  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, head_id{0};
    data_index_init(begin, bs, batches, head_id, num_heads);

    for (int64_t i = begin; i < end; i++) {
      int64_t loc_val = loc[bs];
      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_id * k_strideH;
      scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_id * v_strideH;
      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_id * nk_strideH;
      const scalar_t* new_value_ptr = value + bs * nv_strideN + head_id * nv_strideH;
      copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
      copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);

      // move to the next index
      data_index_step(bs, batches, head_id, num_heads);
    }
  });

  using Vec = at::vec::Vectorized<float>;

  // block length for k_buffer and v_buffer
  constexpr int64_t BLOCK_N = 256;

  // strides
  const int64_t q_strideM = num_heads * head_size;
  const int64_t q_strideH = head_size;
@@ -785,55 +997,209 @@ void decode_attention_kernel_impl(
    }
  });

  // parallel on [batches, num_heads]
  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
    // NB: here we use logits[b][h][0] as acc, since
    // for the first kv split (kv_id == 0):
    //   m_delta = std::exp(-inf) = 0
    //   e_logic = std::exp(0) = 1
    //   acc = acc * m_delta + tv * e_logic = tv
  decode_accumulate_kv_splits(
      output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
} // MHA

template <typename scalar_t, typename index_t, int64_t BLOCK_N>
void decode_attention_mla_kernel_impl(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    const scalar_t* __restrict__ query,
    const scalar_t* __restrict__ k_buffer,
    const scalar_t* __restrict__ v_buffer,
    const index_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices,
    const int64_t* __restrict__ seq_lens,
    scalar_t* __restrict__ buffer,
    int64_t batches,
    int64_t num_heads,
    int64_t head_size,
    int64_t head_size_v,
    int64_t num_kv_splits,
    int64_t k_strideN,
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    float scaling,
    float logit_cap,
    int64_t max_num_reqs,
    int64_t max_context_len,
    int64_t max_total_num_tokens,
    int64_t buffer_size_per_thread) {
  using Vec = at::vec::Vectorized<float>;

  // block length for heads
  // (smaller BLOCK_H when batches is small exposes more parallel work items)
  const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);

  // strides
  const int64_t q_strideM = num_heads * head_size;
  const int64_t q_strideH = head_size;
  const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
  const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
  const int64_t l_stride2 = head_size_v + 1;

  TORCH_CHECK(logit_cap == 0.f, "decode MLA: expect no logit_cap.");

  // partition the heads into blocks for parallel
  const int64_t num_blocks = div_up(num_heads, BLOCK_H);

  // parallel on [batches, num_blocks, num_kv_splits]
  at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, block_id{0}, kv_id{0};
    data_index_init(begin, bs, batches, block_id, num_blocks, kv_id, num_kv_splits);

    int tid = at::get_thread_num();
    scalar_t* __restrict__ Btmp0 = buffer + tid * buffer_size_per_thread;
    scalar_t* __restrict__ Btmp1 = Btmp0 + BLOCK_N * head_size;

    // init Btmp1 just once for each thread to prevent NaN
    // Btmp0 is not needed as it computes full K every single time
    fill_stub(Btmp1, 0.f, BLOCK_N * head_size_v);

    alignas(64) float s_i[BLOCK_H * BLOCK_N];
    float* __restrict__ s_delta = s_i;
    alignas(64) scalar_t s_delta2[BLOCK_H * BLOCK_N];

    alignas(64) float s_prime[BLOCK_H];
    alignas(64) float m_prime[BLOCK_H];
    alignas(64) float m_delta[BLOCK_H];

    for (int64_t i = begin; i < end; ++i) {
      float* __restrict__ acc = attn_logits + i * l_stride1;
      const int64_t h_start = block_id * BLOCK_H;
      const int64_t h_end = std::min(block_id * BLOCK_H + BLOCK_H, num_heads);
      const int64_t h_size = h_end - h_start;

      float s_prime = 0.f;
      float m_prime = -std::numeric_limits<scalar_t>::infinity();
      // get query
      const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;

      // update acc with the partial results from each kv split
      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
        float* __restrict__ tv = acc + kv_id * l_stride2;
        const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
      int64_t seq_len_kv = seq_lens[bs];
      int64_t req_pool_id = req_pool_indices[bs];
      TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
      TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");

        float m_i = std::max(tlogic, m_prime);
        float m_delta = std::exp(m_prime - m_i);
        float e_logic = std::exp(tlogic - m_i);
        if (kv_id != 0) {
          at::vec::map2<float>(
              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
              acc,
              acc,
              tv,
              head_size_v);
        }
      const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
      const int64_t kv_start = kv_id * SPLIT_SIZE;
      const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);

        s_prime = s_prime * m_delta + e_logic;
        m_prime = m_i;
      fill_stub(s_prime, 0.f, BLOCK_H);
      fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);

      // get v_prime, and init to zero
      float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
      for (int64_t h = 0; h < h_size; ++h) {
        fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
      }

      copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
    }
  });
}
      // loop over K and V sequence with BLOCK_N
      for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
        int64_t n_size = std::min(BLOCK_N, kv_end - n);
        const int64_t padded_n_size = div_up(int(n_size), TILE_K) * TILE_K;

template <typename scalar_t, typename index_t>
        // get key and pack
        pack_vnni<scalar_t, index_t>(
            /* dst0 */ Btmp0,
            /* dst1 */ Btmp1,
            /* src */ k_buffer + /* head_kv_id */ 0 * k_strideH,
            /* ind */ req_to_token + req_pool_id * max_context_len + n,
            /* N */ n_size,
            /* K */ head_size,
            /* Kv */ head_size_v,
            /* ld_src */ k_strideN,
            /* ld_dst0 */ BLOCK_N,
            /* ld_dst1 */ head_size_v);

        // calculate s_i <- Q @ K
        at::native::cpublas::brgemm(
            /* M */ h_size,
            /* N */ n_size,
            /* K */ head_size,
            /* lda */ q_strideH,
            /* ldb */ BLOCK_N,
            /* ldc */ BLOCK_N,
            /* add_C */ false,
            /* A */ q_ptr,
            /* B */ Btmp0,
            /* C */ s_i);
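        // (brgemm computes C[M, N] (+)= A[M, K] @ B[K, N]; for 16-bit inputs B
        //  is consumed in the VNNI layout that pack_vnni produced into Btmp0)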

        const Vec scale_vec = Vec(scaling);
        for (int64_t h = 0; h < h_size; ++h) {
          // s_i <- s_i * scale
          at::vec::map<float>(
              [scale_vec](Vec x) { return x * scale_vec; }, s_i + h * BLOCK_N, s_i + h * BLOCK_N, n_size);

          // m_i: max value per row
          float m_i = at::vec::reduce_all<float>(
              [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size);
          m_i = std::max(m_i, m_prime[h]);

          // m_delta <- exp(m' - m_i)
          m_delta[h] = std::exp(m_prime[h] - m_i);

          // s_delta <- exp(s_i - m_i)
          at::vec::map<float>(
              [m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size);

          // s' <- s' * m_delta + sum(s_delta)
          s_prime[h] *= m_delta[h];
          s_prime[h] += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size);

          m_prime[h] = m_i;

          // v' <- v' * m_delta
          float scale_m = m_delta[h];
          at::vec::map<float>(
              [scale_m](Vec x) { return x * Vec(scale_m); },
              v_prime + h * l_stride1,
              v_prime + h * l_stride1,
              head_size_v);

          // pad s_delta with 0 first and then convert to scalar_t
          fill_stub(s_delta + h * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
          copy_stub<scalar_t, BLOCK_N>(s_delta2 + h * BLOCK_N, s_delta + h * BLOCK_N);
        }

        // calculate V' <- s_delta @ V + V'
        at::native::cpublas::brgemm(
            /* M */ h_size,
            /* N */ head_size_v,
            /* K */ padded_n_size, // n_size
            /* lda */ BLOCK_N,
            /* ldb */ head_size_v,
            /* ldc */ l_stride1,
            /* add_C */ true,
            /* A */ s_delta2,
            /* B */ Btmp1,
            /* C */ v_prime);
      } // loop with KV blocks

      // only update v' when kv_split_size > 0
      if (kv_end > kv_start) {
        for (int64_t h = 0; h < h_size; ++h) {
          float s = 1 / s_prime[h];
          at::vec::map<float>(
              [s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v);
          (v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]);
        }
      }
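      // (the extra float at offset head_size_v carries this split's logsumexp,
      //  m' + log(s'), read back as `tlogic` in decode_accumulate_kv_splits)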

      // move to the next index
      data_index_step(bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
    }
    at::native::cpublas::brgemm_release();
  });

  decode_accumulate_kv_splits(
      output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
} // MLA

template <typename scalar_t, typename index_t, int64_t BLOCK_N>
void decode_attention_grouped_kernel_impl(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    const scalar_t* __restrict__ query,
    scalar_t* __restrict__ k_buffer,
    scalar_t* __restrict__ v_buffer,
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    const int64_t* __restrict__ loc,
    const scalar_t* __restrict__ k_buffer,
    const scalar_t* __restrict__ v_buffer,
    const index_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices,
    const int64_t* __restrict__ seq_lens,
@@ -847,37 +1213,13 @@ void decode_attention_grouped_kernel_impl(
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    int64_t nk_strideN,
    int64_t nk_strideH,
    int64_t nv_strideN,
    int64_t nv_strideH,
    float scaling,
    float logit_cap,
    int64_t max_num_reqs,
    int64_t max_context_len,
    int64_t max_total_num_tokens) {
  at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, head_kv_id{0};
    data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);

    for (int64_t i = begin; i < end; i++) {
      int64_t loc_val = loc[bs];
      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
      scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
      const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
      copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
      copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);

      // move to the next index
      data_index_step(bs, batches, head_kv_id, num_heads_kv);
    }
  });

  using Vec = at::vec::Vectorized<float>;

  // block length for k_buffer and v_buffer
  constexpr int64_t BLOCK_N = 256;
  // block length for heads
  // we parallelize on [batches, divup(num_heads, BLOCK_H), num_kv_splits]
  // use smaller BLOCK_H when batches is small to utilize all cores
@@ -960,7 +1302,7 @@ void decode_attention_grouped_kernel_impl(
          [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
          s_i,
          s_i,
          n_size);
          BLOCK_H * BLOCK_N);
    }

    // update the scaling coefficients
@@ -1015,40 +1357,9 @@ void decode_attention_grouped_kernel_impl(
    }
  });

  // parallel on [batches, num_heads]
  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
    // NB: same as above
    for (int64_t i = begin; i < end; ++i) {
      float* __restrict__ acc = attn_logits + i * l_stride1;

      float s_prime = 0.f;
      float m_prime = -std::numeric_limits<scalar_t>::infinity();

      // update acc with the partial results from each kv split
      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
        float* __restrict__ tv = acc + kv_id * l_stride2;
        const float tlogic = (acc + kv_id * l_stride2)[head_size_v];

        float m_i = std::max(tlogic, m_prime);
        float m_delta = std::exp(m_prime - m_i);
        float e_logic = std::exp(tlogic - m_i);
        if (kv_id != 0) {
          at::vec::map2<float>(
              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
              acc,
              acc,
              tv,
              head_size_v);
        }

        s_prime = s_prime * m_delta + e_logic;
        m_prime = m_i;
      }

      copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
    }
  });
}
  decode_accumulate_kv_splits(
      output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
} // GQA/MQA

} // anonymous namespace

@@ -1134,19 +1445,50 @@ void decode_attention_cpu(
      "decode: expect req_pool_indices to be int64, got ",
      req_pool_indices.scalar_type());

  // check if we have MLA here
  void* k_buffer_data = k_buffer.data_ptr();
  void* v_buffer_data = v_buffer.data_ptr();
  const bool is_mla = (k_buffer_data == v_buffer_data) && (num_heads_kv == 1) && (head_size == head_size_v + 64);
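  // (heuristic matches the DeepSeek-style MLA layout referenced above, where
  //  k_buffer and v_buffer alias a single latent cache: head_size_v latent
  //  dims plus 64 rope dims per token)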

  // block length for k_buffer and v_buffer
  constexpr int BLOCK_N = 256;

  // buffer for packing k_cache and v_cache
  int num_threads = at::get_num_threads();
  int64_t size_per_thread = is_mla ? BLOCK_N * head_size + BLOCK_N * head_size_v : 0;
  auto buffer = at::empty({num_threads, size_per_thread}, k_buffer.options());
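  // (per-thread scratch: BLOCK_N x head_size for the packed keys (Btmp0) plus
  //  BLOCK_N x head_size_v for the packed values (Btmp1) in the MLA kernel)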

  AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] {
    AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] {
      // update the kv buffer
      decode_set_kv_buffer(
          (scalar_t*)k_buffer_data,
          (scalar_t*)v_buffer_data,
          key.data_ptr<scalar_t>(),
          value.data_ptr<scalar_t>(),
          loc.data_ptr<int64_t>(),
          num_seqs,
          num_heads_kv,
          head_size,
          head_size_v,
          k_strideN,
          k_strideH,
          v_strideN,
          v_strideH,
          nk_strideN,
          nk_strideH,
          nv_strideN,
          nv_strideH,
          is_mla);

      if (num_heads == num_heads_kv) {
        // MHA
        decode_attention_kernel_impl<scalar_t, index_t>(
        decode_attention_kernel_impl<scalar_t, index_t, BLOCK_N>(
            output.data_ptr<scalar_t>(),
            attn_logits.data_ptr<float>(),
            query.data_ptr<scalar_t>(),
            k_buffer.data_ptr<scalar_t>(),
            v_buffer.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            value.data_ptr<scalar_t>(),
            loc.data_ptr<int64_t>(),
            (const scalar_t*)k_buffer_data,
            (const scalar_t*)v_buffer_data,
            req_to_token.data_ptr<index_t>(),
            req_pool_indices.data_ptr<int64_t>(),
            seq_lens.data_ptr<int64_t>(),
@@ -1159,26 +1501,46 @@ void decode_attention_cpu(
            k_strideH,
            v_strideN,
            v_strideH,
            nk_strideN,
            nk_strideH,
            nv_strideN,
            nv_strideH,
            sm_scale,
            logit_cap,
            max_num_reqs,
            max_context_len,
            max_total_num_tokens);
      } else {
        // GQA/MQA/MLA
        decode_attention_grouped_kernel_impl<scalar_t, index_t>(
      } else if (is_mla) {
        // MLA
        decode_attention_mla_kernel_impl<scalar_t, index_t, BLOCK_N>(
            output.data_ptr<scalar_t>(),
            attn_logits.data_ptr<float>(),
            query.data_ptr<scalar_t>(),
            k_buffer.data_ptr<scalar_t>(),
            v_buffer.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            value.data_ptr<scalar_t>(),
            loc.data_ptr<int64_t>(),
            (const scalar_t*)k_buffer_data,
            (const scalar_t*)v_buffer_data,
            req_to_token.data_ptr<index_t>(),
            req_pool_indices.data_ptr<int64_t>(),
            seq_lens.data_ptr<int64_t>(),
            buffer.data_ptr<scalar_t>(),
            num_seqs,
            num_heads,
            head_size,
            head_size_v,
            num_kv_splits,
            k_strideN,
            k_strideH,
            v_strideN,
            v_strideH,
            sm_scale,
            logit_cap,
            max_num_reqs,
            max_context_len,
            max_total_num_tokens,
            size_per_thread);
      } else {
        // GQA/MQA
        decode_attention_grouped_kernel_impl<scalar_t, index_t, BLOCK_N>(
            output.data_ptr<scalar_t>(),
            attn_logits.data_ptr<float>(),
            query.data_ptr<scalar_t>(),
            (const scalar_t*)k_buffer_data,
            (const scalar_t*)v_buffer_data,
            req_to_token.data_ptr<index_t>(),
            req_pool_indices.data_ptr<int64_t>(),
            seq_lens.data_ptr<int64_t>(),
@@ -1192,10 +1554,6 @@ void decode_attention_cpu(
            k_strideH,
            v_strideN,
            v_strideH,
            nk_strideN,
            nk_strideH,
            nv_strideN,
            nv_strideH,
            sm_scale,
            logit_cap,
            max_num_reqs,