diff --git a/sgl-kernel/csrc/cpu/decode.cpp b/sgl-kernel/csrc/cpu/decode.cpp
index 899987677..7f55232e8 100644
--- a/sgl-kernel/csrc/cpu/decode.cpp
+++ b/sgl-kernel/csrc/cpu/decode.cpp
@@ -1,4 +1,5 @@
 #include "common.h"
+#include "gemm.h"
 #include "vec.h"
 
 namespace {
@@ -11,19 +12,144 @@ namespace {
 // 4. provide amx kernel for index_gemm_kernel_nn when M = 16
 //
 
-inline void fill_stub(float* __restrict__ out, float val, int64_t size) {
-  using Vec = at::vec::Vectorized<float>;
-  const Vec data_vec(val);
-  at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size);
+#if defined(CPU_CAPABILITY_AVX512)
+// key: from [N, 32] to [32/2, N, 2]
+// val: from [N, 32] to [N/2, 32, 2]
+template <typename scalar_t, typename index_t>
+inline void pack_vnni_Nx32(
+    scalar_t* __restrict__ dst0,
+    scalar_t* __restrict__ dst1,
+    const scalar_t* __restrict__ src,
+    const index_t* __restrict__ ind,
+    int N,
+    int ld_src,
+    int ld_dst0,
+    int ld_dst1,
+    bool convert_v) {
+  __m512i vinputs[16];
+  int n = 0;
+  for (; n < N; ++n) {
+    vinputs[n] = _mm512_loadu_si512(src + ind[n] * ld_src);
+  }
+  // padding with zero to avoid uninitialized vectors
+  for (; n < 16; ++n) {
+    vinputs[n] = _mm512_set1_epi32(0);
+  }
+
+  // pack value, skip 64 elems for deepseek
+  // handle 2 vectors at a time from [2, 32] to [32, 2]
+  if (convert_v) {
+    for (int n = 0; n < 16; n += 2) {
+      __m512i d0, d1;
+      std::tie(d0, d1) = transpose_2x32_16bit(vinputs[n], vinputs[n + 1]);
+      _mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2, d0);
+      _mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2 + 32, d1);
+    }
+  }
+
+  // pack key
+  transpose_16x16_32bit(vinputs);
+
+  const __mmask16 vmask = (1 << N) - 1;
+  for (int k = 0; k < 16; ++k) {
+    _mm512_mask_storeu_epi32(dst0 + k * ld_dst0 * 2, vmask, vinputs[k]);
+  }
+}
+#endif
+
+// [NOTE]: MLA vnni format conversion
+//
+// here we apply the same strategy as `FlashMLA`:
+//   each kv_cache is loaded once and packed twice (L2 cache hit)
+//
+//   * for key: from [N, K/2, 2] to [K/2, N, 2]
+//   * for value: from [N/2, 2, Kv] to [N/2, Kv, 2]
+//
+template <typename scalar_t, typename index_t>
+void pack_vnni(
+    scalar_t* __restrict__ dst0,
+    scalar_t* __restrict__ dst1,
+    const scalar_t* __restrict__ src,
+    const index_t* __restrict__ ind,
+    int N,
+    int K,
+    int Kv,
+    int ld_src,
+    int ld_dst0,
+    int ld_dst1) {
+#if defined(CPU_CAPABILITY_AVX512)
+  const int NB = div_up(N, 16);
+  const int KB = K / 32;    // no remainder
+  const int KBv = Kv / 32;  // no remainder
+
+  for (int nb = 0; nb < NB; ++nb) {
+    for (int kb = 0; kb < KB; ++kb) {
+      // handle 16x512bits each block
+      int nb_size = std::min(N - nb * 16, 16);
+      pack_vnni_Nx32(
+          /* dst0    */ dst0 + ((kb * 32) >> 1) * ld_dst0 * 2 + nb * 16 * 2,
+          /* dst1    */ dst1 + ((nb * 16) >> 1) * ld_dst1 * 2 + kb * 32 * 2,
+          /* src     */ src + kb * 32,
+          /* ind     */ ind + nb * 16,
+          /* N       */ nb_size,
+          /* ld_src  */ ld_src,
+          /* ld_dst0 */ ld_dst0,
+          /* ld_dst1 */ ld_dst1,
+          /* cvt_v   */ kb < KBv);
+    }
+  }
+#else
+  for (int n = 0; n < N; ++n) {
+    index_t index = ind[n];
+    for (int k = 0; k < K / 2; ++k) {
+      for (int d = 0; d < 2; ++d) {
+        dst0[k * ld_dst0 * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
+      }
+    }
+  }
+  // from [N/2, 2, Kv] to [N/2, Kv, 2]
+  for (int n = 0; n < (N >> 1) * 2; n += 2) {
+    index_t index0 = ind[n + 0];
+    index_t index1 = ind[n + 1];
+    for (int k = 0; k < Kv; ++k) {
+      dst1[(n >> 1) * ld_dst1 * 2 + k * 2 + 0] = src[index0 * ld_src + k];
+      dst1[(n >> 1) * ld_dst1 * 2 + k * 2 + 1] = src[index1 * ld_src + k];
+    }
+  }
+  if (N % 2 != 0) {
+    index_t index = ind[N - 1];
+    for (int k = 0; k < Kv; ++k) {
+      dst1[(N >> 1) * ld_dst1 * 2 + k * 2 + 0] = src[index * ld_src + k];
+      dst1[(N >> 1) * ld_dst1 * 2 + k * 2 + 1] = 0;
+    }
+  }
+#endif
+}
+
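For reference, the key-side VNNI re-layout that `pack_vnni` performs can be written in a few lines of plain C++. The sketch below is a standalone toy with hypothetical tiny sizes: element (n, k) of the key block ends up at `dst[k/2][n][k%2]`, which is what lets an AMX/AVX512 bf16 dot product consume two K elements per lane.

```cpp
#include <cstdint>
#include <cstdio>

// Scalar reference of the key re-layout in pack_vnni (hypothetical sizes).
// src is [N, K] row-major; dst is [K/2, N, 2]: adjacent K elements stay
// paired so a VNNI/AMX bf16 FMA reads two K values per lane.
int main() {
  constexpr int N = 4, K = 8;
  uint16_t src[N][K], dst[K / 2][N][2];
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k) src[n][k] = n * 100 + k;

  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K / 2; ++k)
      for (int d = 0; d < 2; ++d) dst[k][n][d] = src[n][k * 2 + d];

  // e.g. dst[0][2] holds {src[2][0], src[2][1]} == {200, 201}
  printf("%u %u\n", dst[0][2][0], dst[0][2][1]);
  return 0;
}
```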
+template <typename scalar_t>
+inline void fill_stub(scalar_t* __restrict__ out, float val, int64_t size) {
+  using Vec = at::vec::Vectorized<scalar_t>;
+  constexpr int kVecSize = Vec::size();
+  const Vec data_vec = Vec(static_cast<scalar_t>(val));
+  int64_t d = 0;
+#pragma GCC unroll 4
+  for (; d <= size - kVecSize; d += kVecSize) {
+    data_vec.store(out + d);
+  }
+  if (size - d > 0) {
+    data_vec.store(out + d, size - d);
+  }
 }
 
 template <typename scalar_t>
 inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) {
   using bVec = at::vec::Vectorized<scalar_t>;
   using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
   const fVec s_fvec = fVec(s);
   int64_t d = 0;
-  for (; d <= size - bVec::size(); d += bVec::size()) {
+#pragma GCC unroll 4
+  for (; d <= size - kVecSize; d += kVecSize) {
     fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
     fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
     bVec out_bvec = convert_from_float_ext(a_fvec0, a_fvec1);
@@ -37,8 +163,10 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc,
 template <typename scalar_t>
 inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
   using bVec = at::vec::Vectorized<scalar_t>;
+  constexpr int kVecSize = bVec::size();
   int64_t d = 0;
-  for (; d <= size - bVec::size(); d += bVec::size()) {
+#pragma GCC unroll 4
+  for (; d <= size - kVecSize; d += kVecSize) {
     bVec out_bvec = bVec::loadu(src + d);
     out_bvec.store(out + d);
   }
@@ -47,6 +175,26 @@ inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ s
   }
 }
 
+template <typename scalar_t, int BLOCK_N>
+inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
+  static_assert(BLOCK_N % 32 == 0);
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+
+  constexpr int COLS = BLOCK_N / 16;
+  auto store = [&](auto i) {
+    constexpr int col = i % COLS;
+    // for COLS = 2, 4 use 512bit store
+    if constexpr (col % 2 == 0) {
+      fVec a_fvec0 = fVec::loadu(input + col * 16);
+      fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
+      bVec out_bvec = convert_from_float_ext(a_fvec0, a_fvec1);
+      out_bvec.store(out + col * 16);
+    }
+  };
+  Unroll<COLS>{}(store);
+}
+
 // GEMM handles query @ key (indexed) x scale
 // A   : [M, K]
 // B   : [N, K] indexed
@@ -619,16 +767,105 @@ void index_gemm_kernel_nn(
   }
 }
 
-template <typename scalar_t, typename index_t>
-void decode_attention_kernel_impl(
-    scalar_t* __restrict__ output,
-    float* __restrict__ attn_logits,
-    const scalar_t* __restrict__ query,
+template <typename scalar_t>
+void decode_set_kv_buffer(
     scalar_t* __restrict__ k_buffer,
     scalar_t* __restrict__ v_buffer,
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
     const int64_t* __restrict__ loc,
+    int64_t batches,
+    int64_t num_heads_kv,
+    int64_t head_size,
+    int64_t head_size_v,
+    int64_t k_strideN,
+    int64_t k_strideH,
+    int64_t v_strideN,
+    int64_t v_strideH,
+    int64_t nk_strideN,
+    int64_t nk_strideH,
+    int64_t nv_strideN,
+    int64_t nv_strideH,
+    bool is_mla) {
+  at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
+    int64_t bs{0}, head_kv_id{0};
+    data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);
+
+    for (int64_t i = begin; i < end; i++) {
+      int64_t loc_val = loc[bs];
+      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
+      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
+      copy_stub(k_buffer_ptr, new_key_ptr, head_size);
+      if (!is_mla) {
+        scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
+        const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
+        copy_stub(v_buffer_ptr, new_value_ptr, head_size_v);
+      }
+
+      // move to the next index
+      data_index_step(bs, batches, head_kv_id, num_heads_kv);
+    }
+  });
+}
+
+template <typename scalar_t>
+void decode_accumulate_kv_splits(
+    scalar_t* __restrict__ output,
+    float* __restrict__ attn_logits,
+    int64_t batches,
+    int64_t num_heads,
+    int64_t head_size_v,
+    int64_t num_kv_splits,
+    int64_t l_stride1,
+    int64_t l_stride2) {
+  using Vec = at::vec::Vectorized<float>;
+
+  // parallel on [batches, num_heads]
+  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
+    // NB: here we use logits[b][h][0] as acc, since
+    // for the first kv split (kv_id == 0):
+    //   m_delta = std::exp(-inf) = 0
+    //   e_logic = std::exp(0) = 1
+    //   acc = acc * m_delta + tv * e_logic = tv
+    for (int64_t i = begin; i < end; ++i) {
+      float* __restrict__ acc = attn_logits + i * l_stride1;
+
+      float s_prime = 0.f;
+      float m_prime = -std::numeric_limits<float>::infinity();
+
+      // update acc with the partial results from each kv_split
+      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
+        float* __restrict__ tv = acc + kv_id * l_stride2;
+        const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
+
+        float m_i = std::max(tlogic, m_prime);
+        float m_delta = std::exp(m_prime - m_i);
+        float e_logic = std::exp(tlogic - m_i);
+        if (kv_id != 0) {
+          at::vec::map2(
+              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
+              acc,
+              acc,
+              tv,
+              head_size_v);
+        }
+
+        s_prime = s_prime * m_delta + e_logic;
+        m_prime = m_i;
+      }
+
+      copy_stub(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
+    }
+  });
+}
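The merge in `decode_accumulate_kv_splits` is the standard log-sum-exp combination used by flash-decoding: each split stores its normalized partial output plus the statistic `m + log(s)` at index `head_size_v`. A minimal scalar sketch of the same update, with hypothetical two-split values, reproduces the full softmax-weighted average:

```cpp
#include <cmath>
#include <cstdio>

// Toy check: softmax-weighted average of `value` over two splits, merged
// with the (m + log(s)) trick used by decode_accumulate_kv_splits.
int main() {
  const float logit[4] = {1.0f, 3.0f, 2.0f, 0.5f};
  const float value[4] = {10.f, 20.f, 30.f, 40.f};

  float acc = 0.f, s_prime = 0.f, m_prime = -INFINITY;
  for (int split = 0; split < 2; ++split) {
    // per-split pass: local max m, sum s, normalized partial output tv
    float m = -INFINITY, s = 0.f, tv = 0.f;
    for (int j = split * 2; j < split * 2 + 2; ++j) m = fmaxf(m, logit[j]);
    for (int j = split * 2; j < split * 2 + 2; ++j) {
      float e = expf(logit[j] - m);
      s += e;
      tv += e * value[j];
    }
    tv /= s;                     // the v' each split writes out
    float tlogic = m + logf(s);  // the statistic stored at [head_size_v]

    // merge step: identical to the kernel's m_delta / e_logic update
    float m_i = fmaxf(tlogic, m_prime);
    float m_delta = expf(m_prime - m_i);
    float e_logic = expf(tlogic - m_i);
    acc = acc * m_delta + tv * e_logic;
    s_prime = s_prime * m_delta + e_logic;
    m_prime = m_i;
  }
  printf("merged = %f\n", acc / s_prime);  // equals the full softmax average
  return 0;
}
```

Note how the first split degenerates exactly as the kernel's NB comment says: `m_delta = exp(-inf) = 0` and `e_logic = 1`, so `acc` simply becomes the first split's partial output.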
+template <typename scalar_t, typename index_t, int64_t BLOCK_N>
+void decode_attention_kernel_impl(
+    scalar_t* __restrict__ output,
+    float* __restrict__ attn_logits,
+    const scalar_t* __restrict__ query,
+    const scalar_t* __restrict__ k_buffer,
+    const scalar_t* __restrict__ v_buffer,
     const index_t* __restrict__ req_to_token,
     const int64_t* __restrict__ req_pool_indices,
     const int64_t* __restrict__ seq_lens,
@@ -641,38 +878,13 @@ void decode_attention_kernel_impl(
     int64_t k_strideH,
     int64_t v_strideN,
     int64_t v_strideH,
-    int64_t nk_strideN,
-    int64_t nk_strideH,
-    int64_t nv_strideN,
-    int64_t nv_strideH,
     float scaling,
     float logit_cap,
     int64_t max_num_reqs,
     int64_t max_context_len,
     int64_t max_total_num_tokens) {
-  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
-    int64_t bs{0}, head_id{0};
-    data_index_init(begin, bs, batches, head_id, num_heads);
-
-    for (int64_t i = begin; i < end; i++) {
-      int64_t loc_val = loc[bs];
-      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_id * k_strideH;
-      scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_id * v_strideH;
-      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_id * nk_strideH;
-      const scalar_t* new_value_ptr = value + bs * nv_strideN + head_id * nv_strideH;
-      copy_stub(k_buffer_ptr, new_key_ptr, head_size);
-      copy_stub(v_buffer_ptr, new_value_ptr, head_size_v);
-
-      // move to the next index
-      data_index_step(bs, batches, head_id, num_heads);
-    }
-  });
-
   using Vec = at::vec::Vectorized<float>;
 
-  // block length for k_buffer and v_buffer
-  constexpr int64_t BLOCK_N = 256;
-
   // strides
   const int64_t q_strideM = num_heads * head_size;
   const int64_t q_strideH = head_size;
@@ -785,55 +997,209 @@ void decode_attention_kernel_impl(
     }
   });
 
-  // parallel on [batches, num_heads]
-  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
-    // NB: here we use logits[b][h][0] as acc, since
-    // for the first kv split (kv_id == 0):
-    //   m_delta = std::exp(-inf) = 0
-    //   e_logic = std::exp(0) = 1
-    //   acc = acc * m_delta + tv * e_logic = tv
+  decode_accumulate_kv_splits(
+      output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
+}  // MHA
+
+template <typename scalar_t, typename index_t, int64_t BLOCK_N>
+void decode_attention_mla_kernel_impl(
+    scalar_t* __restrict__ output,
+    float* __restrict__ attn_logits,
+    const scalar_t* __restrict__ query,
+    const scalar_t* __restrict__ k_buffer,
+    const scalar_t* __restrict__ v_buffer,
+    const index_t* __restrict__ req_to_token,
+    const int64_t* __restrict__ req_pool_indices,
+    const int64_t* __restrict__ seq_lens,
+    scalar_t* __restrict__ buffer,
+    int64_t batches,
+    int64_t num_heads,
+    int64_t head_size,
+    int64_t head_size_v,
+    int64_t num_kv_splits,
+    int64_t k_strideN,
+    int64_t k_strideH,
+    int64_t v_strideN,
+    int64_t v_strideH,
+    float scaling,
+    float logit_cap,
+    int64_t max_num_reqs,
+    int64_t max_context_len,
+    int64_t max_total_num_tokens,
+    int64_t buffer_size_per_thread) {
+  using Vec = at::vec::Vectorized<float>;
+
+  // block length for heads
+  const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
+
+  // strides
+  const int64_t q_strideM = num_heads * head_size;
+  const int64_t q_strideH = head_size;
+  const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
+  const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
+  const int64_t l_stride2 = head_size_v + 1;
+
+  TORCH_CHECK(logit_cap == 0.f, "decode MLA: expect no logit_cap.");
+
+  // partition the heads into blocks for parallel
+  const int64_t num_blocks = div_up(num_heads, BLOCK_H);
+
+  // parallel on [batches, num_blocks, num_kv_splits]
+  at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
+    int64_t bs{0}, block_id{0}, kv_id{0};
+    data_index_init(begin, bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
+
+    int tid = at::get_thread_num();
+    scalar_t* __restrict__ Btmp0 = buffer + tid * buffer_size_per_thread;
+    scalar_t* __restrict__ Btmp1 = Btmp0 + BLOCK_N * head_size;
+
+    // init Btmp1 just once for each thread to prevent NaN
+    // Btmp0 is not needed as it computes full K every single time
+    fill_stub(Btmp1, 0.f, BLOCK_N * head_size_v);
+
+    alignas(64) float s_i[BLOCK_H * BLOCK_N];
+    float* __restrict__ s_delta = s_i;
+    alignas(64) scalar_t s_delta2[BLOCK_H * BLOCK_N];
+
+    alignas(64) float s_prime[BLOCK_H];
+    alignas(64) float m_prime[BLOCK_H];
+    alignas(64) float m_delta[BLOCK_H];
+
     for (int64_t i = begin; i < end; ++i) {
-      float* __restrict__ acc = attn_logits + i * l_stride1;
+      const int64_t h_start = block_id * BLOCK_H;
+      const int64_t h_end = std::min(block_id * BLOCK_H + BLOCK_H, num_heads);
+      const int64_t h_size = h_end - h_start;
 
-      float s_prime = 0.f;
-      float m_prime = -std::numeric_limits<float>::infinity();
+      // get query
+      const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
 
-      // update acc with from each kv_split
-      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
-        float* __restrict__ tv = acc + kv_id * l_stride2;
-        const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
+      int64_t seq_len_kv = seq_lens[bs];
+      int64_t req_pool_id = req_pool_indices[bs];
+      TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
+      TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
 
-        float m_i = std::max(tlogic, m_prime);
-        float m_delta = std::exp(m_prime - m_i);
-        float e_logic = std::exp(tlogic - m_i);
-        if (kv_id != 0) {
-          at::vec::map2(
-              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
-              acc,
-              acc,
-              tv,
-              head_size_v);
-        }
+      const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
+      const int64_t kv_start = kv_id * SPLIT_SIZE;
+      const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);
 
-        s_prime = s_prime * m_delta + e_logic;
-        m_prime = m_i;
+      fill_stub(s_prime, 0.f, BLOCK_H);
+      fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);
+
+      // get v_prime, and init to zero
+      float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
+      for (int64_t h = 0; h < h_size; ++h) {
+        fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
       }
 
-      copy_stub(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
-    }
-  });
-}
+      // loop over K and V sequence with BLOCK_N
+      for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
+        int64_t n_size = std::min(BLOCK_N, kv_end - n);
+        const int64_t padded_n_size = div_up(int(n_size), TILE_K) * TILE_K;
 
-template <typename scalar_t, typename index_t>
+        // get key and pack
+        pack_vnni(
+            /* dst0    */ Btmp0,
+            /* dst1    */ Btmp1,
+            /* src     */ k_buffer + /* head_kv_id */ 0 * k_strideH,
+            /* ind     */ req_to_token + req_pool_id * max_context_len + n,
+            /* N       */ n_size,
+            /* K       */ head_size,
+            /* Kv      */ head_size_v,
+            /* ld_src  */ k_strideN,
+            /* ld_dst0 */ BLOCK_N,
+            /* ld_dst1 */ head_size_v);
+
+        // calculate s_i <- Q @ K
+        at::native::cpublas::brgemm(
+            /* M     */ h_size,
+            /* N     */ n_size,
+            /* K     */ head_size,
+            /* lda   */ q_strideH,
+            /* ldb   */ BLOCK_N,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ q_ptr,
+            /* B     */ Btmp0,
+            /* C     */ s_i);
+
+        const Vec scale_vec = Vec(scaling);
+        for (int64_t h = 0; h < h_size; ++h) {
+          // s_i <- s_i * scale
+          at::vec::map(
+              [scale_vec](Vec x) { return x * scale_vec; }, s_i + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
+
+          // m_i: max value per row
+          float m_i = at::vec::reduce_all(
+              [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size);
+          m_i = std::max(m_i, m_prime[h]);
+
+          // m_delta <- exp(m' - m_i)
+          m_delta[h] = std::exp(m_prime[h] - m_i);
+
+          // s_delta <- exp(s_i - m_i)
+          at::vec::map(
+              [m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
+
+          // s' <- s' * m_delta + sum(s_delta)
+          s_prime[h] *= m_delta[h];
+          s_prime[h] += at::vec::reduce_all([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size);
+
+          m_prime[h] = m_i;
+
+          // v' <- v' * m_delta
+          float scale_m = m_delta[h];
+          at::vec::map(
+              [scale_m](Vec x) { return x * Vec(scale_m); },
+              v_prime + h * l_stride1,
+              v_prime + h * l_stride1,
+              head_size_v);
+
+          // pad s_delta with 0 first and then convert to scalar_t
+          fill_stub(s_delta + h * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
+          copy_stub<scalar_t, BLOCK_N>(s_delta2 + h * BLOCK_N, s_delta + h * BLOCK_N);
+        }
+
+        // calculate V' <- s_delta @ V + V'
+        at::native::cpublas::brgemm(
+            /* M     */ h_size,
+            /* N     */ head_size_v,
+            /* K     */ padded_n_size,  // n_size
+            /* lda   */ BLOCK_N,
+            /* ldb   */ head_size_v,
+            /* ldc   */ l_stride1,
+            /* add_C */ true,
+            /* A     */ s_delta2,
+            /* B     */ Btmp1,
+            /* C     */ v_prime);
+      }  // loop with KV blocks
+
+      // only update v' when kv_split_size > 0
+      if (kv_end > kv_start) {
+        for (int64_t h = 0; h < h_size; ++h) {
+          float s = 1 / s_prime[h];
+          at::vec::map(
+              [s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v);
+          (v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]);
+        }
+      }
+
+      // move to the next index
+      data_index_step(bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
+    }
+    at::native::cpublas::brgemm_release();
+  });
+
+  decode_accumulate_kv_splits(
+      output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
+}  // MLA
+
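Each BLOCK_N iteration of the MLA kernel performs one online-softmax step per head: scale the new scores, fold the block maximum into (m', s'), rescale the running v', then feed exp(s_i - m_i) through the second GEMM against the packed values. A scalar sketch of that step for a single head, with hypothetical values and head_size_v = 1 to keep it short:

```cpp
#include <cmath>
#include <cstdio>

// One online-softmax block update for a single head, mirroring the
// s_delta / m_delta bookkeeping in decode_attention_mla_kernel_impl.
int main() {
  float s_prime = 0.f, m_prime = -INFINITY, v_prime = 0.f;

  const float scores[2][3] = {{1.f, 2.f, 0.f}, {3.f, 1.f, 2.f}};  // BLOCK_N = 3
  const float values[2][3] = {{1.f, 1.f, 1.f}, {2.f, 2.f, 2.f}};

  for (int blk = 0; blk < 2; ++blk) {
    // m_i: block max folded into the running max
    float m_i = m_prime;
    for (int n = 0; n < 3; ++n) m_i = fmaxf(m_i, scores[blk][n]);

    float m_delta = expf(m_prime - m_i);  // rescales the old s' and v'
    s_prime *= m_delta;
    v_prime *= m_delta;
    for (int n = 0; n < 3; ++n) {
      float s_delta = expf(scores[blk][n] - m_i);  // exp(s_i - m_i)
      s_prime += s_delta;
      v_prime += s_delta * values[blk][n];  // the second brgemm does this step
    }
    m_prime = m_i;
  }
  printf("attn out = %f\n", v_prime / s_prime);
  return 0;
}
```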
+template <typename scalar_t, typename index_t, int64_t BLOCK_N>
 void decode_attention_grouped_kernel_impl(
     scalar_t* __restrict__ output,
     float* __restrict__ attn_logits,
     const scalar_t* __restrict__ query,
-    scalar_t* __restrict__ k_buffer,
-    scalar_t* __restrict__ v_buffer,
-    const scalar_t* __restrict__ key,
-    const scalar_t* __restrict__ value,
-    const int64_t* __restrict__ loc,
+    const scalar_t* __restrict__ k_buffer,
+    const scalar_t* __restrict__ v_buffer,
     const index_t* __restrict__ req_to_token,
     const int64_t* __restrict__ req_pool_indices,
     const int64_t* __restrict__ seq_lens,
@@ -847,37 +1213,13 @@ void decode_attention_grouped_kernel_impl(
     int64_t k_strideH,
     int64_t v_strideN,
     int64_t v_strideH,
-    int64_t nk_strideN,
-    int64_t nk_strideH,
-    int64_t nv_strideN,
-    int64_t nv_strideH,
     float scaling,
     float logit_cap,
     int64_t max_num_reqs,
     int64_t max_context_len,
     int64_t max_total_num_tokens) {
-  at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
-    int64_t bs{0}, head_kv_id{0};
-    data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);
-
-    for (int64_t i = begin; i < end; i++) {
-      int64_t loc_val = loc[bs];
-      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
-      scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
-      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
-      const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
-      copy_stub(k_buffer_ptr, new_key_ptr, head_size);
-      copy_stub(v_buffer_ptr, new_value_ptr, head_size_v);
-
-      // move to the next index
-      data_index_step(bs, batches, head_kv_id, num_heads_kv);
-    }
-  });
-
   using Vec = at::vec::Vectorized<float>;
 
-  // block length for k_buffer and v_buffer
-  constexpr int64_t BLOCK_N = 256;
   // block length for heads
   // we parallel on [batches, divup(num_heads, BLOCK_H), num_kv_splits]
   // use smaller BLOCK_H when batches is small to utilize all cores
@@ -960,7 +1302,7 @@ void decode_attention_grouped_kernel_impl(
             [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
             s_i,
             s_i,
-            n_size);
+            BLOCK_H * BLOCK_N);
       }
 
       // update the scaling coefficients
@@ -1015,40 +1357,9 @@ void decode_attention_grouped_kernel_impl(
     }
   });
 
-  // parallel on [batches, num_heads]
-  at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
-    // NB: same as above
-    for (int64_t i = begin; i < end; ++i) {
-      float* __restrict__ acc = attn_logits + i * l_stride1;
-
-      float s_prime = 0.f;
-      float m_prime = -std::numeric_limits<float>::infinity();
-
-      // update acc with from each kv_split
-      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
-        float* __restrict__ tv = acc + kv_id * l_stride2;
-        const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
-
-        float m_i = std::max(tlogic, m_prime);
-        float m_delta = std::exp(m_prime - m_i);
-        float e_logic = std::exp(tlogic - m_i);
-        if (kv_id != 0) {
-          at::vec::map2(
-              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
-              acc,
-              acc,
-              tv,
-              head_size_v);
-        }
-
-        s_prime = s_prime * m_delta + e_logic;
-        m_prime = m_i;
-      }
-
-      copy_stub(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
-    }
-  });
-}
+  decode_accumulate_kv_splits(
+      output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
+}  // GQA/MQA
 
 }  // anonymous namespace
 
@@ -1134,19 +1445,50 @@ void decode_attention_cpu(
       "decode: expect req_pool_indices to be int64, got ",
       req_pool_indices.scalar_type());
 
+  // check if we have MLA here
+  void* k_buffer_data = k_buffer.data_ptr();
+  void* v_buffer_data = v_buffer.data_ptr();
+  const bool is_mla = (k_buffer_data == v_buffer_data) && (num_heads_kv == 1) && (head_size == head_size_v + 64);
+
+  // block length for k_buffer and v_buffer
+  constexpr int BLOCK_N = 256;
+
+  // buffer for packing k_cache and v_cache
+  int num_threads = at::get_num_threads();
+  int64_t size_per_thread = is_mla ? BLOCK_N * head_size + BLOCK_N * head_size_v : 0;
+  auto buffer = at::empty({num_threads, size_per_thread}, k_buffer.options());
+
   AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] {
     AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] {
+      // update the kv buffer
+      decode_set_kv_buffer(
+          (scalar_t*)k_buffer_data,
+          (scalar_t*)v_buffer_data,
+          key.data_ptr<scalar_t>(),
+          value.data_ptr<scalar_t>(),
+          loc.data_ptr<int64_t>(),
+          num_seqs,
+          num_heads_kv,
+          head_size,
+          head_size_v,
+          k_strideN,
+          k_strideH,
+          v_strideN,
+          v_strideH,
+          nk_strideN,
+          nk_strideH,
+          nv_strideN,
+          nv_strideH,
+          is_mla);
+
       if (num_heads == num_heads_kv) {
         // MHA
-        decode_attention_kernel_impl<scalar_t, index_t>(
+        decode_attention_kernel_impl<scalar_t, index_t, BLOCK_N>(
             output.data_ptr<scalar_t>(),
             attn_logits.data_ptr<float>(),
             query.data_ptr<scalar_t>(),
-            k_buffer.data_ptr<scalar_t>(),
-            v_buffer.data_ptr<scalar_t>(),
-            key.data_ptr<scalar_t>(),
-            value.data_ptr<scalar_t>(),
-            loc.data_ptr<int64_t>(),
+            (const scalar_t*)k_buffer_data,
+            (const scalar_t*)v_buffer_data,
             req_to_token.data_ptr<index_t>(),
             req_pool_indices.data_ptr<int64_t>(),
             seq_lens.data_ptr<int64_t>(),
@@ -1159,26 +1501,46 @@ void decode_attention_cpu(
             k_strideH,
             v_strideN,
             v_strideH,
-            nk_strideN,
-            nv_strideH,
-            nv_strideN,
-            nv_strideH,
             sm_scale,
             logit_cap,
             max_num_reqs,
             max_context_len,
             max_total_num_tokens);
-      } else {
-        // GQA/MQA/MLA
-        decode_attention_grouped_kernel_impl<scalar_t, index_t>(
+      } else if (is_mla) {
+        // MLA
+        decode_attention_mla_kernel_impl<scalar_t, index_t, BLOCK_N>(
            output.data_ptr<scalar_t>(),
            attn_logits.data_ptr<float>(),
            query.data_ptr<scalar_t>(),
-            k_buffer.data_ptr<scalar_t>(),
-            v_buffer.data_ptr<scalar_t>(),
-            key.data_ptr<scalar_t>(),
-            value.data_ptr<scalar_t>(),
-            loc.data_ptr<int64_t>(),
+            (const scalar_t*)k_buffer_data,
+            (const scalar_t*)v_buffer_data,
+            req_to_token.data_ptr<index_t>(),
+            req_pool_indices.data_ptr<int64_t>(),
+            seq_lens.data_ptr<int64_t>(),
+            buffer.data_ptr<scalar_t>(),
+            num_seqs,
+            num_heads,
+            head_size,
+            head_size_v,
+            num_kv_splits,
+            k_strideN,
+            k_strideH,
+            v_strideN,
+            v_strideH,
+            sm_scale,
+            logit_cap,
+            max_num_reqs,
+            max_context_len,
+            max_total_num_tokens,
+            size_per_thread);
+      } else {
+        // GQA/MQA
+        decode_attention_grouped_kernel_impl<scalar_t, index_t, BLOCK_N>(
+            output.data_ptr<scalar_t>(),
+            attn_logits.data_ptr<float>(),
+            query.data_ptr<scalar_t>(),
+            (const scalar_t*)k_buffer_data,
+            (const scalar_t*)v_buffer_data,
            req_to_token.data_ptr<index_t>(),
            req_pool_indices.data_ptr<int64_t>(),
            seq_lens.data_ptr<int64_t>(),
@@ -1192,10 +1554,6 @@ void decode_attention_cpu(
            k_strideH,
            v_strideN,
            v_strideH,
-            nk_strideN,
-            nk_strideH,
-            nv_strideN,
-            nv_strideH,
            sm_scale,
            logit_cap,
            max_num_reqs,
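The MLA path is detected purely from the cache layout: MLA stores K and V in the same tensor (one latent head, compressed KV plus 64 rope dims), so `k_buffer` and `v_buffer` alias and `head_size` exceeds `head_size_v` by exactly 64. A small sketch of the same predicate and the per-thread pack-buffer size it implies; the 576 = 512 + 64 dimensions below are assumed DeepSeek-style values, not taken from this diff:

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the is_mla check and scratch sizing in decode_attention_cpu,
// with assumed DeepSeek-style dims (head_size 576 = 512 kv-lora + 64 rope).
int main() {
  const bool same_tensor = true;  // k_buffer.data_ptr() == v_buffer.data_ptr()
  const int64_t num_heads_kv = 1, head_size = 576, head_size_v = 512;
  const bool is_mla = same_tensor && (num_heads_kv == 1) && (head_size == head_size_v + 64);

  // per-thread pack buffer: Btmp0 [BLOCK_N, head_size] for keys,
  // Btmp1 [BLOCK_N, head_size_v] for values, counted in scalar_t elements
  const int64_t BLOCK_N = 256;
  const int64_t size_per_thread = is_mla ? BLOCK_N * head_size + BLOCK_N * head_size_v : 0;
  printf("is_mla=%d, elems/thread=%lld\n", is_mla, (long long)size_per_thread);
  return 0;
}
```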
diff --git a/sgl-kernel/csrc/cpu/extend.cpp b/sgl-kernel/csrc/cpu/extend.cpp
index 9ae36574f..c9f424634 100644
--- a/sgl-kernel/csrc/cpu/extend.cpp
+++ b/sgl-kernel/csrc/cpu/extend.cpp
@@ -10,11 +10,72 @@ namespace {
 // 3. computes attention for prefix and extend separately
 // 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
 //
+
 template <typename index_t>
 inline index_t get_index(index_t* ind, int i) {
   return (ind == nullptr) ? (index_t)i : ind[i];
 }
 
+#if defined(CPU_CAPABILITY_AVX512)
+// key: from [N, 32] to [32/2, N, 2]
+template <typename scalar_t, typename index_t>
+inline void pack_vnni_Nx32(
+    scalar_t* __restrict__ dst,
+    const scalar_t* __restrict__ src,
+    const index_t* __restrict__ ind,
+    int N,
+    int ld_src,
+    int ld_dst) {
+  __m512i vinputs[16];
+
+  int n = 0;
+  for (; n < N; ++n) {
+    index_t index = get_index(ind, n);
+    vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
+  }
+  // padding with zero to avoid uninitialized vectors
+  for (; n < 16; ++n) {
+    vinputs[n] = _mm512_set1_epi32(0);
+  }
+
+  // pack key
+  transpose_16x16_32bit(vinputs);
+
+  const __mmask16 vmask = (1 << N) - 1;
+  for (int k = 0; k < 16; ++k) {
+    _mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
+  }
+}
+
+// value: from [K, 32] to [K/2, 32, 2]
+template <typename scalar_t, typename index_t>
+inline void pack_vnni_Kx32(
+    scalar_t* __restrict__ dst,
+    const scalar_t* __restrict__ src,
+    const index_t* __restrict__ ind,
+    int K,
+    int ld_src,
+    int ld_dst) {
+  __m512i vinputs[2];
+
+  int k = 0;
+  for (; k < K; ++k) {
+    index_t index = get_index(ind, k);
+    vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
+  }
+  // padding with zero to avoid uninitialized vectors
+  for (; k < 2; ++k) {
+    vinputs[k] = _mm512_set1_epi32(0);
+  }
+
+  // pack value
+  __m512i d0, d1;
+  std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
+  _mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
+  _mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
+}
+#endif
+
 // convert to vnni format
 // from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
 template <typename scalar_t, typename index_t>
@@ -26,6 +87,25 @@ void pack_vnni(
     int K,
     int ld_src,
     int ld_dst) {
+#if defined(CPU_CAPABILITY_AVX512)
+  const int NB = div_up(N, 16);
+  const int KB = K / 32;  // no remainder
+  const bool is_indexed = ind != nullptr;
+
+  for (int nb = 0; nb < NB; ++nb) {
+    for (int kb = 0; kb < KB; ++kb) {
+      // handle 16x512bits each block
+      int nb_size = std::min(N - nb * 16, 16);
+      pack_vnni_Nx32(
+          /* dst    */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
+          /* src    */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
+          /* ind    */ is_indexed ? ind + nb * 16 : nullptr,
+          /* N      */ nb_size,
+          /* ld_src */ ld_src,
+          /* ld_dst */ ld_dst);
+    }
+  }
+#else
   for (int n = 0; n < N; ++n) {
     index_t index = get_index(ind, n);
     for (int k = 0; k < K / 2; ++k) {
@@ -34,6 +114,7 @@ void pack_vnni(
       }
     }
   }
+#endif
 }
 
 // convert to vnni format
@@ -47,6 +128,25 @@ void pack_vnni2(
     int N,
     int ld_src,
     int ld_dst) {
+#if defined(CPU_CAPABILITY_AVX512)
+  const int KB = div_up(K, 2);
+  const int NB = N / 32;  // no remainder
+  const bool is_indexed = ind != nullptr;
+
+  for (int kb = 0; kb < KB; ++kb) {
+    for (int nb = 0; nb < NB; ++nb) {
+      // handle 2x512bits each block
+      int kb_size = std::min(K - kb * 2, 2);
+      pack_vnni_Kx32(
+          /* dst    */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
+          /* src    */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
+          /* ind    */ is_indexed ? ind + kb * 2 : nullptr,
+          /* K      */ kb_size,
+          /* ld_src */ ld_src,
+          /* ld_dst */ ld_dst);
+    }
+  }
+#else
   int k = 0;
   for (; k < (K >> 1) * 2; k += 2) {
     index_t index0 = get_index(ind, k + 0);
@@ -64,21 +164,17 @@ void pack_vnni2(
     }
     k += 2;
   }
-  // TODO: check whether we can skip this!
-  // const int padded_K = div_up(K, TILE_K) * TILE_K;
-  // for (; k < padded_K; ++k) {
-  //   for (int n = 0; n < N; ++n) {
-  //     dst[k * ld_dst + n] = static_cast<scalar_t>(0);
-  //   }
-  // }
+#endif
 }
 
 template <typename scalar_t>
 inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
   using Vec = at::vec::Vectorized<scalar_t>;
+  constexpr int kVecSize = Vec::size();
   const Vec data_vec = Vec(static_cast<scalar_t>(val));
   int d = 0;
-  for (; d <= size - Vec::size(); d += Vec::size()) {
+#pragma GCC unroll 4
+  for (; d <= size - kVecSize; d += kVecSize) {
     data_vec.store(out + d);
   }
   if (size - d > 0) {
@@ -110,9 +206,11 @@ template <typename scalar_t>
 inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
   using bVec = at::vec::Vectorized<scalar_t>;
   using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
   const fVec s_fvec = fVec(s);
   int d = 0;
-  for (; d <= size - bVec::size(); d += bVec::size()) {
+#pragma GCC unroll 4
+  for (; d <= size - kVecSize; d += kVecSize) {
     fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
     fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
     bVec out_bvec = convert_from_float_ext(a_fvec0, a_fvec1);
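`pack_vnni2` applies the same idea to the value matrix: rows k and k+1 of a [K, N] tile are interleaved element-wise into [K/2, N, 2], with a zero row paired in when K is odd. A scalar sketch with hypothetical sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Scalar reference of the value re-layout in pack_vnni2 (hypothetical sizes):
// [K, N] row-major -> [K/2, N, 2], pairing rows (k, k+1); odd K pads a zero row.
int main() {
  constexpr int K = 3, N = 4;
  uint16_t src[K][N], dst[(K + 1) / 2][N][2] = {};
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) src[k][n] = k * 10 + n;

  for (int k = 0; k < K; k += 2)
    for (int n = 0; n < N; ++n) {
      dst[k / 2][n][0] = src[k][n];
      dst[k / 2][n][1] = (k + 1 < K) ? src[k + 1][n] : 0;  // zero padding
    }

  // dst[1][2] pairs src[2][2] with the pad: {22, 0}
  printf("%u %u\n", dst[1][2][0], dst[1][2][1]);
  return 0;
}
```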
diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h
index 901939f11..eabbfb7c8 100644
--- a/sgl-kernel/csrc/cpu/gemm.h
+++ b/sgl-kernel/csrc/cpu/gemm.h
@@ -93,6 +93,8 @@ void fused_experts_fp8_kernel_impl(
     scalar_t* __restrict__ ic1,
     scalar_t* __restrict__ ic2,
     scalar_t* __restrict__ A_tmp,
+    scalar_t* __restrict__ B_tmp,
+    float* __restrict__ C_tmp,
     const scalar_t* __restrict__ input,
     const at::Float8_e4m3fn* __restrict__ packed_w1,
     const at::Float8_e4m3fn* __restrict__ packed_w2,
@@ -135,6 +137,8 @@ void shared_expert_fp8_kernel_impl(
     scalar_t* __restrict__ output,
     scalar_t* __restrict__ ic0,
     scalar_t* __restrict__ ic1,
+    scalar_t* __restrict__ B_tmp,
+    float* __restrict__ C_tmp,
     const scalar_t* __restrict__ input,
     const at::Float8_e4m3fn* __restrict__ packed_w1,
     const at::Float8_e4m3fn* __restrict__ packed_w2,
diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
index 1feded107..3bba40786 100644
--- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp
+++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
@@ -2,6 +2,9 @@
 #include "gemm.h"
 #include "vec.h"
 
+// we use 4x32 for BLOCK_M
+#define BLOCK_SIZE_M_SCALE 4
+
 namespace {
 
 template <typename scalar_t>
@@ -61,33 +64,38 @@ inline void unpack_B(
   constexpr int BLOCK_N = block_size_n();
   static_assert(BLOCK_N == 32);
 
+  // prefetch distance
+  constexpr int PREFETCH_SIZE_K = 64;
+
+#pragma GCC unroll 4
   for (int k = 0; k < K2; ++k) {
-    for (int n = 0; n < N; n += 64) {  // BLOCK_N = 32
-      __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + n);
-
-      __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
-      __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
-
-      __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
-      __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
-
-      // Apply scale
-      __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
-      __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
-      __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
-      __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
-
-      f0_lo = _mm512_mul_ps(f0_lo, vd);
-      f0_hi = _mm512_mul_ps(f0_hi, vd);
-      f1_lo = _mm512_mul_ps(f1_lo, vd);
-      f1_hi = _mm512_mul_ps(f1_hi, vd);
-
-      bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
-      bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
-
-      _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 0, (__m512i)bf16_0);
-      _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 32, (__m512i)bf16_1);
+    __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
+    if constexpr (PREFETCH_SIZE_K > 0) {
+      _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
     }
+
+    __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
+    __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
+
+    __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
+    __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
+
+    // Apply scale
+    __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
+    __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
+    __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
+    __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
+
+    f0_lo = _mm512_mul_ps(f0_lo, vd);
+    f0_hi = _mm512_mul_ps(f0_hi, vd);
+    f1_lo = _mm512_mul_ps(f1_lo, vd);
+    f1_hi = _mm512_mul_ps(f1_hi, vd);
+
+    bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
+    bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
+
+    _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
+    _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
   }
 #else
   TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
@@ -128,24 +136,30 @@ struct tinygemm_kernel_nn<at::BFloat16, BLOCK_M, BLOCK_N
     __m512bh va;
     __m512bh vb[COLS];
     __m512 vc[ROWS * COLS];
+    __m512 vsum[ROWS * COLS];
+
+    // scale vector for the current K block
+    __m512 vscale;
+
+    // number of blocks along K, and prefetch distance for the block scales
+    const int KB = div_up(K, BLOCK_K);
+    constexpr int PREFETCH_SIZE_KB = 1;
 
     auto loadc = [&](auto i) {
       if constexpr (has_bias) {
        constexpr int col = i % COLS;
        vc[i] = _mm512_loadu_ps(bias + col * 16);
      } else {
        vc[i] = _mm512_set1_ps(0.f);
      }
    };
    Unroll<ROWS * COLS>{}(loadc);
 
-    const int K2 = K >> 1;
     const int lda2 = lda >> 1;
     const int ldb2 = ldb;  // ldb * 2 >> 1;
     const float* a_ptr = reinterpret_cast<const float*>(A);
@@ -155,11 +169,11 @@ struct tinygemm_kernel_nn<at::BFloat16, BLOCK_M, BLOCK_N
       if constexpr (col == 0) {
         va = (__m512bh)(_mm512_set1_epi32(a_ptr[row * lda2 + k]));
+        if constexpr (PREFETCH_SIZE_K > 0) {
+          _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
+        }
       }
       if constexpr (row == 0) {
         if constexpr (col % 2 == 0) {
@@ -167,47 +181,40 @@ struct tinygemm_kernel_nn<at::BFloat16, BLOCK_M, BLOCK_N
           __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
           if constexpr (PREFETCH_SIZE_K > 0) {
             _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
           }
-
-          __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
-          __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
-
-          __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
-          __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
-
-          // Apply scale
-          __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
-          __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
-          __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
-          __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
-
-          f0_lo = _mm512_mul_ps(f0_lo, vd);
-          f0_hi = _mm512_mul_ps(f0_hi, vd);
-          f1_lo = _mm512_mul_ps(f1_lo, vd);
-          f1_hi = _mm512_mul_ps(f1_hi, vd);
-
-          vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
-          vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
+          vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
+          vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
         }
       }
-      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
+      vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
     };
-    for (int k = 0; k < K2; ++k) {
-      Unroll<ROWS * COLS>{}(compute, k);
+
+    constexpr int BLOCK_K2 = BLOCK_K >> 1;
+    for (int kb = 0; kb < KB; ++kb) {
+      int kb_start = kb * BLOCK_K2;
+      int kb_end = std::min(K >> 1, kb_start + BLOCK_K2);
+      // 1. load scale vector
+      vscale = _mm512_set1_ps(scale[kb]);
+      if constexpr (PREFETCH_SIZE_KB > 0) {
+        _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
+      }
+      // 2. zero vsum for each block
+      Unroll<ROWS * COLS>{}([&](auto i) { vsum[i] = _mm512_setzero_ps(); });
+      // 3. accumulate across each block
+      for (int k = kb_start; k < kb_end; ++k) {
+        Unroll<ROWS * COLS>{}(compute, k);
+      }
+      // 4. apply scale
+      Unroll<ROWS * COLS>{}([&](auto i) { vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); });
     }
 
     auto storec = [&](auto i) {
       constexpr int row = i / COLS;
       constexpr int col = i % COLS;
-      // for COLS = 1, 3 use 256bit store
-      // for COLS = 2, 4 use 512bit store
-      if constexpr (COLS % 2 == 0) {
-        if constexpr (col % 2 == 0) {
-          _mm512_storeu_si512(
-              reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
-              (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
-        }
-      } else {
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
+      // for COLS = 2, 4 use 512bit store
+      if constexpr (col % 2 == 0) {
+        _mm512_storeu_si512(
+            reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
+            (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
       }
     };
     Unroll<ROWS * COLS>{}(storec);
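The restructured loop computes C as a sum over K blocks of raw bf16 dot products, each folded in with that block's dequant scale, instead of scaling B up front. A scalar sketch of the same blocked-dequant accumulation (hypothetical sizes, BLOCK_K = 2):

```cpp
#include <cstdio>

// Scalar sketch of block-quantized accumulation as in the reworked tinygemm
// kernel: accumulate raw products per K block (vsum), then fold in that
// block's scale with one FMA (vc += vsum * scale[kb]).
int main() {
  constexpr int K = 4, BLOCK_K = 2, KB = K / BLOCK_K;
  const float a[K] = {1, 2, 3, 4};
  const float b[K] = {5, 6, 7, 8};  // dequantized B would be b[k] * scale[kb]
  const float scale[KB] = {0.5f, 0.25f};

  float c = 0.f;
  for (int kb = 0; kb < KB; ++kb) {
    float sum = 0.f;  // vsum, zeroed per block
    for (int k = kb * BLOCK_K; k < (kb + 1) * BLOCK_K; ++k) sum += a[k] * b[k];
    c += sum * scale[kb];  // _mm512_fmadd_ps(vsum, vscale, vc)
  }
  printf("c = %f\n", c);  // (1*5 + 2*6) * 0.5 + (3*7 + 4*8) * 0.25
  return 0;
}
```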
@@ -266,22 +273,18 @@ struct brgemm {
       int ldc) {
     constexpr int BLOCK_N = block_size_n();
 
-    // [BLOCK_K, BLOCK_N] -> [BLOCK_K / 2, BLOCK_N * 2]
-    const int ldb_tmp = block_size_n();
+    // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
+    const int ldb_tmp = BLOCK_N;
 
-    static_assert(BLOCK_K == 128);
-
-    // accumulate across K per BLOCK_K
     for (int k = 0; k < K; k += BLOCK_K) {
       int kb_size = std::min(BLOCK_K, K - k);
       int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
-      unpack_B(Btmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
-
-      const bool add_C = (k != 0);
-      at::native::cpublas::brgemm(M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp);
+      unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
     }
 
+    at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
+
     // copy from Ctmp to C
     for (int m = 0; m < M; ++m) {
       if constexpr (has_bias) {
@@ -328,34 +331,18 @@ void tinygemm_kernel(
       int64_t nb_size = std::min(BLOCK_N, N - nb_start);
 
       switch (mb_size << 4 | nb_size >> 4) {
-        // mb_size = 1
         case 0x12:
           LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
           break;
-        case 0x14:
-          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
-          break;
-        // mb_size = 2
         case 0x22:
           LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
           break;
-        case 0x24:
-          LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
-          break;
-        // mb_size = 3
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
          break;
-        case 0x34:
-          LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
-          break;
-        // mb_size = 4
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
          break;
-        case 0x44:
-          LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
-          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
      }
@@ -370,14 +357,16 @@ void fp8_scaled_mm_kernel_impl(
     const at::Float8_e4m3fn* __restrict__ mat2,
     const float* __restrict__ scales2,
     const float* __restrict__ bias,
+    scalar_t* __restrict__ buffer,
     int64_t M,
     int64_t N,
     int64_t K,
     int64_t mat1_strideM,
     int64_t out_strideM,
     int64_t block_size_N,
-    int64_t block_size_K) {
-  constexpr int64_t BLOCK_M = block_size_m();
+    int64_t block_size_K,
+    int64_t buffer_size_per_thread) {
+  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
   constexpr int64_t BLOCK_N = block_size_n();
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
@@ -393,10 +382,9 @@ void fp8_scaled_mm_kernel_impl(
     int64_t mb{0}, nb{0};
     data_index_init(begin, mb, MB, nb, NB);
 
-    // for brgemm, use float32 for accumulate
-    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
-    // for brgemm when mat2 is float8_e4m3
-    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
+    int tid = at::get_thread_num();
+    scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
+    float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
 
     for (int64_t i = begin; i < end; ++i) {
       UNUSED(i);
@@ -507,6 +495,7 @@ at::Tensor fp8_scaled_mm_cpu(
   int64_t block_size_N = block_size[0];
   int64_t block_size_K = block_size[1];
 
+  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
   constexpr int64_t BLOCK_N = block_size_n();
   TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
   TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
@@ -531,6 +520,12 @@ at::Tensor fp8_scaled_mm_cpu(
     bias_data = bias.value().data_ptr<float>();
   }
 
+  // Btmp : [T, BLOCK_N * K]
+  // Ctmp : [T, BLOCK_M * BLOCK_N]
+  int num_threads = at::get_num_threads();
+  int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
+  auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
+
   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
     fp8_scaled_mm_kernel_impl<scalar_t>(
         out.data_ptr<scalar_t>(),
         mat1.data_ptr<scalar_t>(),
         packed_w.data_ptr<at::Float8_e4m3fn>(),
         scales2.data_ptr<float>(),
         bias_data,
+        buffer.data_ptr<scalar_t>(),
         M,
         N,
         K,
         mat1_strideM,
         out_strideM,
         block_size_N,
-        block_size_K);
+        block_size_K,
+        size_per_thread);
   });
 
   return out;
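The shared buffer is allocated in the output dtype (bf16/fp16), so the float Ctmp region is carved out of it by reserving two scalar_t elements per float: `size_per_thread = BLOCK_N * K` (Btmp) `+ BLOCK_M * BLOCK_N * 2` (Ctmp in fp32). A sketch of the carve-up with hypothetical dims:

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the per-thread scratch carving used by fp8_scaled_mm_kernel_impl:
// one bf16-typed buffer holds Btmp, and its tail doubles as a float Ctmp
// (each float occupies two 2-byte slots, hence the "* 2" in the sizing).
int main() {
  using scalar_t = uint16_t;  // stand-in for bf16
  const int64_t BLOCK_M = 16, BLOCK_N = 32, K = 256;

  const int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
  scalar_t* base = new scalar_t[size_per_thread];

  scalar_t* Btmp = base;                              // [BLOCK_N, K] bf16
  float* Ctmp = (float*)(void*)(Btmp + BLOCK_N * K);  // [BLOCK_M, BLOCK_N] fp32

  printf("bytes/thread = %lld\n", (long long)(size_per_thread * sizeof(scalar_t)));
  Ctmp[0] = 1.f;  // both regions live inside the same allocation
  delete[] base;
  return 0;
}
```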
diff --git a/sgl-kernel/csrc/cpu/interface.cpp b/sgl-kernel/csrc/cpu/interface.cpp
index 61d9686d6..969a6bad4 100644
--- a/sgl-kernel/csrc/cpu/interface.cpp
+++ b/sgl-kernel/csrc/cpu/interface.cpp
@@ -33,11 +33,11 @@ void initialize(int64_t size, int64_t rank) {
   world_rank = rank;
   is_initialized = true;
 
-  auto addr_string = std::getenv("MASTER_ADDR");
+  const char* addr_string = std::getenv("MASTER_ADDR");
   if (addr_string == NULL) {
     addr_string = "";
   }
-  auto port_string = std::getenv("MASTER_PORT");
+  const char* port_string = std::getenv("MASTER_PORT");
   if (port_string == NULL) {
     port_string = "";
   }
diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp
index ea6b0cc2c..2a7d163bb 100644
--- a/sgl-kernel/csrc/cpu/moe.cpp
+++ b/sgl-kernel/csrc/cpu/moe.cpp
@@ -1080,7 +1080,8 @@ at::Tensor fused_experts_cpu(
   //   6. As_tmp : [M * topk]
   //
   // for fp8 w8a16:
-  //   7. intermediate_cache1 : [M * topk, 2N]
+  //   7. intermediate_cache0 : [M * topk, 2N]
+  //   8. B_tmp : [T, BLOCK_N, std::max(K, N)]
   //
   int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
@@ -1090,7 +1091,7 @@ at::Tensor fused_experts_cpu(
     buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
   }
   if (use_fp8_w8a16) {
-    buffer_size_nbytes += M * topk * 2 * N * 2;
+    buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
   }
 
   auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
@@ -1136,7 +1137,9 @@ at::Tensor fused_experts_cpu(
     } else if (use_fp8_w8a16) {
       // here we just ignore C_tmp as it is not used
       scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
-      scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
+      float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
+      scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
+      scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N));
 
       CHECK_MOE_SCALES_FP8(1, 2);
       fused_experts_fp8_kernel_impl(
@@ -1145,6 +1148,8 @@ at::Tensor fused_experts_cpu(
           intermediate_cache1,
           intermediate_cache2,
           A_tmp,
+          B_tmp,
+          C_tmp,
           hidden_states.data_ptr<scalar_t>(),
           packed_w1.data_ptr<at::Float8_e4m3fn>(),
           packed_w2.data_ptr<at::Float8_e4m3fn>(),
@@ -1258,6 +1263,7 @@ at::Tensor shared_expert_cpu(
   //
   // for fp8 w8a16:
   //   5. intermediate_cache0 : [M, 2N]
+  //   6. B_tmp : [T, BLOCK_M, max(K, N)]
   //
   int num_threads = at::get_num_threads();
   int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1266,7 +1272,7 @@ at::Tensor shared_expert_cpu(
     buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
   }
   if (use_fp8_w8a16) {
-    buffer_size_nbytes += M * 2 * N * 2;
+    buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
   }
 
   auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
@@ -1301,12 +1307,15 @@ at::Tensor shared_expert_cpu(
           K);
     } else if (use_fp8_w8a16) {
       scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
+      scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N));
 
       CHECK_MOE_SCALES_FP8(0, 1);
       shared_expert_fp8_kernel_impl(
           out_hidden_states.data_ptr<scalar_t>(),
           intermediate_cache0,
           intermediate_cache1,
+          B_tmp,
+          C_tmp,
           hidden_states.data_ptr<scalar_t>(),
           packed_w1.data_ptr<at::Float8_e4m3fn>(),
           packed_w2.data_ptr<at::Float8_e4m3fn>(),
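`fused_experts_cpu` hands all of these regions out of one flat `at::kChar` allocation by walking a pointer through it; every `* 2` in the byte counts is the bf16 element size. A simplified sketch of the pattern (sizes and ordering here are made up for illustration):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of carving typed scratch regions out of one flat byte buffer,
// as fused_experts_cpu does with buffer2 (hypothetical sizes).
int main() {
  using scalar_t = uint16_t;  // stand-in for bf16 (2 bytes)
  const int64_t T = 4, BLOCK_M = 16, BLOCK_N = 32, M = 8, topk = 2, K = 64, N = 32;

  const int64_t nbytes = M * topk * 2 * N * 2                       // ic0 (bf16)
                         + T * 2 * BLOCK_M * BLOCK_N * sizeof(float)  // C_tmp
                         + T * BLOCK_N * (K > N ? K : N) * 2;         // B_tmp (bf16)
  std::vector<char> buf(nbytes);

  scalar_t* ic0 = (scalar_t*)buf.data();
  float* C_tmp = (float*)(void*)(ic0 + M * topk * 2 * N);
  scalar_t* B_tmp = (scalar_t*)(void*)(C_tmp + T * 2 * BLOCK_M * BLOCK_N);
  printf("B_tmp offset = %lld bytes\n", (long long)((char*)B_tmp - buf.data()));
  return 0;
}
```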
diff --git a/sgl-kernel/csrc/cpu/moe_fp8.cpp b/sgl-kernel/csrc/cpu/moe_fp8.cpp
index 3aaddacf2..cb891fca2 100644
--- a/sgl-kernel/csrc/cpu/moe_fp8.cpp
+++ b/sgl-kernel/csrc/cpu/moe_fp8.cpp
@@ -142,6 +142,8 @@ void fused_experts_fp8_kernel_impl(
     scalar_t* __restrict__ ic1,
     scalar_t* __restrict__ ic2,
     scalar_t* __restrict__ A_tmp,
+    scalar_t* __restrict__ B_tmp,
+    float* __restrict__ C_tmp,
     const scalar_t* __restrict__ input,
     const at::Float8_e4m3fn* __restrict__ packed_w1,
     const at::Float8_e4m3fn* __restrict__ packed_w2,
@@ -178,9 +180,6 @@ void fused_experts_fp8_kernel_impl(
     int tid = at::get_thread_num();
     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
 
-    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
-    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
-
     bool is_brgemm_used = false;
 
     for (int64_t i = begin; i < end; ++i) {
@@ -212,8 +211,8 @@ void fused_experts_fp8_kernel_impl(
             /* A     */ A,
             /* B     */ B,
             /* C     */ ic0 + offset * 2 * N + nb * BLOCK_N,
-            /* Btmp  */ Btmp,
-            /* Ctmp  */ Ctmp,
+            /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+            /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
             /* scale */ Bs,
             /* M     */ m_size,
             /* N     */ n_size,
@@ -250,9 +249,8 @@ void fused_experts_fp8_kernel_impl(
   // parallel on [MB2, NB2]
   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
-    alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
+    int tid = at::get_thread_num();
     alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
-    alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
 
     bool is_brgemm_used = false;
 
@@ -281,8 +279,8 @@ void fused_experts_fp8_kernel_impl(
             /* A     */ A,
             /* B     */ B,
             /* C     */ C,
-            /* Btmp  */ Btmp,
-            /* Ctmp  */ Ctmp,
+            /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+            /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
             /* scale */ Bs,
             /* M     */ m_size,
             /* N     */ n_size,
@@ -323,6 +321,8 @@ void fused_experts_fp8_kernel_impl(
       TYPE* __restrict__ ic1,                          \
       TYPE* __restrict__ ic2,                          \
       TYPE* __restrict__ A_tmp,                        \
+      TYPE* __restrict__ B_tmp,                        \
+      float* __restrict__ C_tmp,                       \
       const TYPE* __restrict__ input,                  \
       const at::Float8_e4m3fn* __restrict__ packed_w1, \
       const at::Float8_e4m3fn* __restrict__ packed_w2, \
@@ -349,6 +349,8 @@ void shared_expert_fp8_kernel_impl(
     scalar_t* __restrict__ output,
     scalar_t* __restrict__ ic0,
     scalar_t* __restrict__ ic1,
+    scalar_t* __restrict__ B_tmp,
+    float* __restrict__ C_tmp,
     const scalar_t* __restrict__ input,
     const at::Float8_e4m3fn* __restrict__ packed_w1,
     const at::Float8_e4m3fn* __restrict__ packed_w2,
@@ -373,8 +375,7 @@ void shared_expert_fp8_kernel_impl(
   const bool use_brgemm = can_use_brgemm(M);
 
   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
-    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
+    int tid = at::get_thread_num();
 
     for (int64_t i = begin; i < end; ++i) {
       int64_t mb = i / NB;
@@ -386,8 +387,8 @@ void shared_expert_fp8_kernel_impl(
           /* A     */ input + mb * BLOCK_M * K,
           /* B     */ packed_w1 + nb * BLOCK_N * K,
           /* C     */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
-          /* Btmp  */ Btmp,
-          /* Ctmp  */ Ctmp,
+          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
           /* M     */ m_size,
           /* N     */ n_size,
@@ -421,9 +422,8 @@ void shared_expert_fp8_kernel_impl(
   // parallel on [MB2, NB2]
   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
-    alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
+    int tid = at::get_thread_num();
     alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
-    alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
 
     for (int64_t i = begin; i < end; ++i) {
       int64_t mb = i / NB2;
@@ -436,8 +436,8 @@ void shared_expert_fp8_kernel_impl(
           /* A     */ ic1 + mb * BLOCK_M * N,
           /* B     */ packed_w2 + nb * BLOCK_N * N,
           /* C     */ C,
-          /* Btmp  */ Btmp,
-          /* Ctmp  */ Ctmp,
+          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
          /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
          /* M     */ m_size,
          /* N     */ n_size,
@@ -467,6 +467,8 @@ void shared_expert_fp8_kernel_impl(
       TYPE* __restrict__ output,                       \
       TYPE* __restrict__ ic0,                          \
       TYPE* __restrict__ ic1,                          \
+      TYPE* __restrict__ B_tmp,                        \
+      float* __restrict__ C_tmp,                       \
       const TYPE* __restrict__ input,                  \
       const at::Float8_e4m3fn* __restrict__ packed_w1, \
       const at::Float8_e4m3fn* __restrict__ packed_w2, \
diff --git a/sgl-kernel/csrc/cpu/norm.cpp b/sgl-kernel/csrc/cpu/norm.cpp
index 88faafd5b..2c4e1f38d 100644
--- a/sgl-kernel/csrc/cpu/norm.cpp
+++ b/sgl-kernel/csrc/cpu/norm.cpp
@@ -72,6 +72,7 @@ void rmsnorm_kernel_impl(
     const scalar_t* __restrict__ weight,
     int64_t batch_size,
     int64_t hidden_size,
+    int64_t input_strideN,
     float eps = 1e-5) {
   using bVec = at::vec::Vectorized<scalar_t>;
   using fVec = at::vec::Vectorized<float>;
@@ -81,7 +82,7 @@ void rmsnorm_kernel_impl(
     for (int64_t i = begin; i < end; ++i) {
       // local ptrs
       scalar_t* __restrict__ out_ptr = output + i * hidden_size;
-      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
+      const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
 
       fVec sum_fvec = fVec(float(0));
       float sum_val = float(0);
@@ -140,6 +141,7 @@ void fused_add_rmsnorm_kernel_impl(
     float* __restrict__ buffer,
     int64_t batch_size,
     int64_t hidden_size,
+    int64_t input_strideN,
     float eps = 1e-5) {
   using bVec = at::vec::Vectorized<scalar_t>;
   using fVec = at::vec::Vectorized<float>;
@@ -151,7 +153,7 @@ void fused_add_rmsnorm_kernel_impl(
 
     for (int64_t i = begin; i < end; ++i) {
       // local ptrs
-      scalar_t* __restrict__ input_ptr = input + i * hidden_size;
+      scalar_t* __restrict__ input_ptr = input + i * input_strideN;
       scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
 
       fVec sum_fvec = fVec(float(0));
@@ -242,7 +244,7 @@ at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
 at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
   RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
 
-  CHECK_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_INPUT(weight);
   CHECK_DIM(2, input);
   CHECK_DIM(1, weight);
@@ -250,6 +252,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
   int64_t batch_size = input.size(0);
   int64_t hidden_size = input.size(1);
   at::Tensor output = at::empty_like(input);
+  int64_t input_strideN = input.stride(0);
 
   AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
     rmsnorm_kernel_impl<scalar_t>(
@@ -258,6 +261,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
         weight.data_ptr<scalar_t>(),
         batch_size,
         hidden_size,
+        input_strideN,
         eps);
   });
   return output;
@@ -268,7 +272,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
 // weight   : {hidden_size}
 void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
   RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
-  CHECK_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
   CHECK_INPUT(residual);
   CHECK_INPUT(weight);
   CHECK_DIM(2, input);
@@ -279,6 +283,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
   CHECK_EQ(input.size(1), weight.size(0));
   int64_t batch_size = input.size(0);
   int64_t hidden_size = input.size(1);
+  int64_t input_strideN = input.stride(0);
 
   // allocate temp buffer to store x in float32 per thread
   // TODO: implement a singleton for context
@@ -293,6 +298,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
         buffer.data_ptr<float>(),
         batch_size,
         hidden_size,
+        input_strideN,
         eps);
   });
 }
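With `CHECK_LAST_DIM_CONTIGUOUS_INPUT`, `rmsnorm_cpu` now accepts row-strided views as long as the hidden dimension itself is dense; only the row stride changes. A usage sketch (the `rmsnorm_cpu` call is commented because it stands for the op this extension registers, not a stock libtorch function):

```cpp
#include <torch/torch.h>

// Usage sketch: rmsnorm over a last-dim-contiguous, non-contiguous view.
int main() {
  auto full = torch::randn({8, 2 * 1024}, torch::kBFloat16);
  auto x = full.narrow(1, 0, 1024);  // stride(0) == 2048, last dim contiguous
  auto w = torch::ones({1024}, torch::kBFloat16);

  // Previously x.contiguous() was required; now the kernel walks rows
  // with input_strideN = x.stride(0) directly:
  // auto out = rmsnorm_cpu(x, w, /*eps=*/1e-6);
  return 0;
}
```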
end; ++i) { UNUSED(i); @@ -209,7 +209,7 @@ void segment_gemm_kernel_impl( /* A */ A + mb_start * K, /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */, /* C */ C + mb_start * ldc + local_nb_start, - /* Btmp*/ Btmp, + /* Btmp*/ Btmp + tid * BLOCK_N * K, /* Ctmp*/ Ctmp, /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K, /* M */ mb_size, @@ -541,6 +541,10 @@ std::tuple qkv_proj_with_rope( CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K)); CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N)); CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K)); + + const int BLOCK_N = block_size_n(); + const int num_threads = at::get_num_threads(); + auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options); segment_gemm_kernel_impl( qa.data_ptr(), k_input.data_ptr(), @@ -549,6 +553,7 @@ std::tuple qkv_proj_with_rope( kv_a_proj_weight.data_ptr(), q_a_proj_s.data_ptr(), kv_a_proj_s.data_ptr(), + buffer.data_ptr(), num_seqs, q_lora_rank, kv_lora_rank + qk_rope_head_dim, @@ -624,3 +629,74 @@ std::tuple qkv_proj_with_rope( return std::make_tuple(q_input, k_input, v_input); } + +std::tuple qkv_proj_with_rope_fused_weight( + at::Tensor& hidden_states, + at::Tensor& qkv_a_proj_weight, + at::Tensor& q_b_proj_weight, + at::Tensor& w_kc, + at::Tensor& q_a_layernorm_weight, + at::Tensor& kv_a_layernorm_weight, + at::Tensor& positions, + at::Tensor& cos_sin_cache, + double eps, + bool use_int8_w8a8, + bool use_fp8_w8a16, + std::optional qkv_a_proj_scale, + std::optional q_b_proj_scale, + bool is_vnni, + std::optional> block_size, + int64_t q_lora_rank, + int64_t kv_lora_rank, + int64_t qk_rope_head_dim) { + RECORD_FUNCTION( + "sgl-kernel::qkv_proj_with_rope_fused_weight", + std::vector({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc})); + + int64_t hidden_size = hidden_states.size(1); + CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim); + CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8)); + + std::vector weight_chunks = + at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0); + at::Tensor q_a_proj_weight = weight_chunks[0]; + at::Tensor kv_a_proj_weight = weight_chunks[1]; + at::Tensor q_a_proj_s; + at::Tensor kv_a_proj_s; + + if (use_int8_w8a8) { + TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8."); + std::vector scale_chunks = + at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0); + q_a_proj_s = scale_chunks[0]; + kv_a_proj_s = scale_chunks[1]; + } + if (use_fp8_w8a16) { + TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16."); + int64_t block_size_N = block_size.value()[0]; + int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N); + int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N); + std::vector scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0); + q_a_proj_s = scale_chunks[0]; + kv_a_proj_s = scale_chunks[1]; + } + + return qkv_proj_with_rope( + hidden_states, + q_a_proj_weight, + q_b_proj_weight, + kv_a_proj_weight, + w_kc, + q_a_layernorm_weight, + kv_a_layernorm_weight, + positions, + cos_sin_cache, + eps, + use_int8_w8a8, + use_fp8_w8a16, + q_a_proj_s, + q_b_proj_scale, + kv_a_proj_s, + is_vnni, + block_size); +} diff --git a/sgl-kernel/csrc/cpu/shm.cpp b/sgl-kernel/csrc/cpu/shm.cpp index 9f7d89df1..1bf65a9b2 100644 --- a/sgl-kernel/csrc/cpu/shm.cpp +++ 
b/sgl-kernel/csrc/cpu/shm.cpp @@ -54,7 +54,8 @@ void shared_open(SharedData* data, const char* name, size_t nbytes) { void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) { int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); if (d != -1) { - if (nbytes = write(d, bytes, nbytes)) { + nbytes = write(d, bytes, nbytes); + if (nbytes > 0) { shared_open(data, name, nbytes); } } else { @@ -391,7 +392,7 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, static bool is_initialized = false; static int world_rank; -void shm_initialize(int size, int rank, char* addr_string, char* port_string) { +void shm_initialize(int size, int rank, const char* addr_string, const char* port_string) { if (is_initialized) { return; } @@ -409,7 +410,7 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) { struct allreduce_workspace* workspace_buf; struct allreduce_workspace* workspace_buf_other; workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); - snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, rank); shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; @@ -425,7 +426,7 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) { // map shm of all ranks for (int i = 0; i < size; i++) { if (i != rank) { - snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, i); // printf("open %s, %d\n", shm_name, rank); do { shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); @@ -447,13 +448,13 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes) { auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES); // process aligned part #pragma omp parallel for - for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) { + for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) { auto val = _mm256_loadu_si256((__m256i*)((char*)from + i)); _mm256_storeu_si256((__m256i*)((char*)to + i), val); } // process remaining part - for (int i = aligned_bytes; i < n_bytes; i++) { + for (size_t i = aligned_bytes; i < n_bytes; i++) { *((char*)to + i) = *((char*)from + i); } } @@ -481,7 +482,9 @@ void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, siz static int current_buffer = 0; static int state_idx = 0; - enum coll_state copy_current, copy_next; + // init states to case 0 to get rid of "maybe-uninitialized" warning. + enum coll_state copy_current = coll_allreduce_naive__copy_in_done; + enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done; switch (state_idx) { case 0: @@ -526,7 +529,10 @@ void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_ static int current_buffer = 0; static int state_idx = 0; - enum coll_state copy_current, copy_next, reduce_current; + // init states to case 0 to get rid of "maybe-uninitialized" warning. 
@@ -447,13 +448,13 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
   auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
 
   // process aligned part
 #pragma omp parallel for
-  for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
+  for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
     auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
     _mm256_storeu_si256((__m256i*)((char*)to + i), val);
   }
 
   // process remaining part
-  for (int i = aligned_bytes; i < n_bytes; i++) {
+  for (size_t i = aligned_bytes; i < n_bytes; i++) {
     *((char*)to + i) = *((char*)from + i);
   }
 }
@@ -481,7 +482,9 @@ void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, siz
   static int current_buffer = 0;
   static int state_idx = 0;
 
-  enum coll_state copy_current, copy_next;
+  // init states to the case-0 values to silence the "maybe-uninitialized" warning.
+  enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
+  enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
 
   switch (state_idx) {
     case 0:
@@ -526,7 +529,10 @@ void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_
   static int current_buffer = 0;
   static int state_idx = 0;
 
-  enum coll_state copy_current, copy_next, reduce_current;
+  // init states to the case-0 values to silence the "maybe-uninitialized" warning.
+  enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
+  enum coll_state reduce_current = coll_allreduce_naive__reduce_done;
+  enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
 
   // similar to symmetric_naive_all_reduce, but here we only need two sets of
   // states, because distributed naive reduce has two barriers in the algorithm
@@ -601,7 +607,9 @@ void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_
   static int current_buffer = 0;
   static int state_idx = 0;
 
-  enum coll_state copy_current, copy_next;
+  // init states to the case-0 values to silence the "maybe-uninitialized" warning.
+  enum coll_state copy_current = coll_allgather_naive__copy_in_done;
+  enum coll_state copy_next = coll_alt1_allgather_naive__copy_in_done;
 
   switch (state_idx) {
     case 0:
@@ -621,7 +629,6 @@ void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_
   }
   state_idx = (state_idx + 1) % 3;
 
-  int data_size = chunk_size / chunk_el;
   parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
   std::atomic_thread_fence(std::memory_order_release);
   workspace[world_rank]->states[state_group] = copy_current;
@@ -644,7 +651,7 @@ torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, s
   auto data_ptr = (char*)(data.data_ptr());
   auto result_ptr = (char*)(result.data_ptr());
   for (int i = 0; i < dim_count; i++) {
-    for (int offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
+    for (size_t offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
       size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
       size_t chunk_el = chunk_size / dtype_size;
       naive_all_gather(
diff --git a/sgl-kernel/csrc/cpu/shm.h b/sgl-kernel/csrc/cpu/shm.h
index 4419222a1..3e903972c 100644
--- a/sgl-kernel/csrc/cpu/shm.h
+++ b/sgl-kernel/csrc/cpu/shm.h
@@ -5,7 +5,7 @@
 #ifndef __SHM_COLLECTIVES__
 #define __SHM_COLLECTIVES__
 #define VECTOR_LENGTH_IN_BYTES 32
-void shm_initialize(int size, int rank, char* addr_string, char* port_string);
+void shm_initialize(int size, int rank, const char* addr_string, const char* port_string);
 void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
 torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
 #endif
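The loop-index changes above are more than warning hygiene: `aligned_bytes` and `dim_size` are `size_t`, so an `int` index is converted to unsigned in the comparison and can overflow once sizes pass INT_MAX. A small illustration (hypothetical helper, not from the patch; assumes max_buf > 0):

#include <cstddef>

// Counts the MAX_BUF_SIZE-sized chunks the all_gather loop above would walk.
// With an int offset, dim_size > INT_MAX would overflow `offset` (undefined
// behavior) before the loop terminates; size_t matches the bound's type.
size_t count_chunks(size_t dim_size, size_t max_buf) {
  size_t n = 0;
  for (size_t offset = 0; offset < dim_size; offset += max_buf) {
    ++n;
  }
  return n;
}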
diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp
index da8639a35..cdfa4c271 100644
--- a/sgl-kernel/csrc/cpu/topk.cpp
+++ b/sgl-kernel/csrc/cpu/topk.cpp
@@ -534,7 +534,25 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
     int64_t topk,
     bool renormalize,
     int64_t num_expert_group,
-    int64_t topk_group) {
+    int64_t topk_group,
+    int64_t num_fused_shared_experts,
+    std::optional<double> routed_scaling_factor,
+    std::optional<at::Tensor> num_token_non_padded) {
+  // TODO: support num_fused_shared_experts, routed_scaling_factor and
+  // num_token_non_padded. For now, only their default values are accepted.
+  TORCH_CHECK(
+      num_fused_shared_experts == 0,
+      "num_fused_shared_experts must be the default value 0, got: ",
+      num_fused_shared_experts);
+  TORCH_CHECK(
+      !routed_scaling_factor.has_value() || routed_scaling_factor.value() == 1.0f,
+      "routed_scaling_factor must be None or the default value 1.0, got: ",
+      routed_scaling_factor.value());
+  TORCH_CHECK(
+      !num_token_non_padded.has_value(),
+      "num_token_non_padded must be None (the default), got: ",
+      num_token_non_padded.value());
+
   RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
 
   CHECK_INPUT(gating_output);
@@ -594,7 +612,21 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
     int64_t topk,
     bool renormalize,
     int64_t num_expert_group,
-    int64_t topk_group) {
+    int64_t topk_group,
+    int64_t num_fused_shared_experts,
+    std::optional<double> routed_scaling_factor,
+    std::optional<at::Tensor> num_token_non_padded) {
+  // TODO: support num_fused_shared_experts, routed_scaling_factor and
+  // num_token_non_padded. For now, only their default values are accepted.
+  TORCH_CHECK(
+      num_fused_shared_experts == 0,
+      "num_fused_shared_experts must be the default value 0, got: ",
+      num_fused_shared_experts);
+  TORCH_CHECK(
+      !num_token_non_padded.has_value(),
+      "num_token_non_padded must be None (the default), got: ",
+      num_token_non_padded.value());
+
   RECORD_FUNCTION(
       "sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
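One subtlety in the checks above: `num_token_non_padded.value()` appears in the error message of a check that the optional is empty. That is safe because TORCH_CHECK is a macro whose message arguments are only evaluated when the condition is false, i.e. only when the optional actually holds a value. A minimal sketch of that lazy-message pattern (hypothetical `CHECK_ARG`, not the real macro):

#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>

template <typename... Args>
std::string concat_msg(Args&&... args) {
  std::ostringstream ss;
  (ss << ... << args);
  return ss.str();
}

// Because this is a macro, __VA_ARGS__ is only evaluated on failure.
#define CHECK_ARG(cond, ...)                             \
  do {                                                   \
    if (!(cond)) {                                       \
      throw std::runtime_error(concat_msg(__VA_ARGS__)); \
    }                                                    \
  } while (0)

void demo(std::optional<int> num_token_non_padded) {
  // .value() below is only reached when has_value() is true.
  CHECK_ARG(
      !num_token_non_padded.has_value(),
      "num_token_non_padded must be None, got: ",
      num_token_non_padded.value());
}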
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index b718fc1b1..7c26c354f 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -44,7 +44,10 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
     int64_t topk,
     bool renormalize,
     int64_t num_expert_group,
-    int64_t topk_group);
+    int64_t topk_group,
+    int64_t num_fused_shared_experts,
+    std::optional<double> routed_scaling_factor,
+    std::optional<at::Tensor> num_token_non_padded);
 
 std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
     at::Tensor& hidden_states,
@@ -53,7 +56,10 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
     int64_t topk,
     bool renormalize,
     int64_t num_expert_group,
-    int64_t topk_group);
+    int64_t topk_group,
+    int64_t num_fused_shared_experts,
+    std::optional<double> routed_scaling_factor,
+    std::optional<at::Tensor> num_token_non_padded);
 
 // attention
 void decode_attention_cpu(
@@ -182,6 +188,26 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     bool is_vnni,
     std::optional<std::vector<int64_t>> block_size);
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
+    at::Tensor& hidden_states,
+    at::Tensor& qkv_a_proj_weight,
+    at::Tensor& q_b_proj_weight,
+    at::Tensor& w_kc,
+    at::Tensor& q_a_layernorm_weight,
+    at::Tensor& kv_a_layernorm_weight,
+    at::Tensor& positions,
+    at::Tensor& cos_sin_cache,
+    double eps,
+    bool use_int8_w8a8,
+    bool use_fp8_w8a16,
+    std::optional<at::Tensor> qkv_a_proj_scale,
+    std::optional<at::Tensor> q_b_proj_scale,
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size,
+    int64_t q_lora_rank,
+    int64_t kv_lora_rank,
+    int64_t qk_rope_head_dim);
+
 // shared memory init
 void initialize(int64_t size, int64_t rank);
 
@@ -221,13 +247,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
   m.def(
       "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
-      "int topk_group) -> (Tensor, Tensor)");
+      "int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, Tensor? num_token_non_padded) -> "
+      "(Tensor, Tensor)");
   m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
 
   // biased group topk
   m.def(
       "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
-      "renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)");
+      "renormalize, int num_expert_group, int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, "
+      "Tensor? num_token_non_padded) -> (Tensor, Tensor)");
   m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
 
   // decode
@@ -294,6 +322,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "q_b_proj_scale, Tensor? "
      "kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
   m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
+  m.def(
+      "qkv_proj_with_rope_fused_weight(Tensor hidden_states, Tensor qkv_a_proj_weight, Tensor q_b_proj_weight, "
+      "Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
+      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? qkv_a_proj_scale, Tensor? "
+      "q_b_proj_scale, "
+      "bool is_vnni, int[]? block_size, int q_lora_rank, int kv_lora_rank, "
+      "int qk_rope_head_dim) -> (Tensor, Tensor, Tensor)");
+  m.impl("qkv_proj_with_rope_fused_weight", torch::kCPU, &qkv_proj_with_rope_fused_weight);
 
   // shared expert
   m.def(
diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h
index 78e8b8f17..9f8eaad18 100644
--- a/sgl-kernel/csrc/cpu/vec.h
+++ b/sgl-kernel/csrc/cpu/vec.h
@@ -30,6 +30,22 @@ convert_from_float_ext(const Vectorized& a, const Vectorize
 
 #define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
 
+// convert e4m3 to bf16 by bit manipulation; this path does not handle NaN.
+inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
+  const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
+
+  const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
+  const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
+  const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
+  const __m512i nonsign = _mm512_or_si512(exp, mant);
+
+  const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
+  const __m512i combined = _mm512_or_si512(nonsign, sign);
+
+  const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
+  return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
+}
+
 inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
   // The following conversion is without denorm behavior, that is to say,
   // Max subnorm : S.0000.111 = 0.875 * 2**(-6)
@@ -84,7 +100,7 @@ inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
 
 inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
 #ifdef SGLANG_CPU_FP8_CVT_FTZ
-  return cvt_e4m3_bf16_intrinsic_without_denorm(a);
+  return cvt_e4m3_bf16_intrinsic_no_nan(a);
 #else
   return cvt_e4m3_bf16_intrinsic_with_denorm(a);
 #endif
 }
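For reference, here is a scalar model of what `cvt_e4m3_bf16_intrinsic_no_nan` computes per 8-bit lane. It mirrors the vector code bit for bit, including the FTZ-style behavior inherited from the `SGLANG_CPU_FP8_CVT_FTZ` path: denormals and -0.0 are run through the normal-number formula and NaN is not special-cased. Illustrative only, not part of the patch:

#include <cstdint>

uint16_t e4m3_to_bf16_bits(uint8_t x) {
  if (x == 0) return 0;                       // the maskz_mov zeroes only all-zero lanes
  uint16_t mant = (x & 0x07) << 4;            // 3-bit mantissa -> top of bf16's 7-bit mantissa
  uint16_t exp = ((x & 0x78) >> 3) + 120;     // rebias: e4m3 bias 7 -> bf16 bias 127
  uint16_t sign = (uint16_t)(x & 0x80) << 8;  // sign to bit 15
  return sign | (uint16_t)(exp << 7) | mant;
}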
@@ -172,4 +188,102 @@ inline void quantize_row_int8(
 }
 #endif
 
+// transpose utils
+// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
+#if defined(CPU_CAPABILITY_AVX512)
+inline void transpose_16x16_32bit(__m512i* v) {
+  __m512i v1[16];
+  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
+  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
+  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
+  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
+  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
+  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
+  v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
+  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
+  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
+  v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
+  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
+  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
+  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
+  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
+  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
+  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
+
+  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
+  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
+  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
+  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
+  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
+  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
+  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
+  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
+  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
+  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
+  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
+  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
+  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
+  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
+  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
+  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
+
+  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
+  v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
+  v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
+  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
+  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
+  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
+  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
+  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
+  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
+  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
+  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
+  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
+  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
+  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
+  v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
+  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
+
+  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
+  v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
+  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
+  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
+  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
+  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
+  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
+  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
+  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
+  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
+  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
+  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
+  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
+  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
+  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
+  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
+}
+
+// remove warning: ignoring attributes on template argument '__m512i' [-Wignored-attributes]
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wignored-attributes"
+
+// transpose from [2, 32] to [32, 2]
+inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
+  // r0: {a0, a1, ..., a31}
+  // r1: {b0, b1, ..., b31}
+  //
+  // d0: {a0, b0, ..., a15, b15}
+  // d1: {a16, b16, ..., a31, b31}
+  //
+  __m512i d0 = _mm512_unpacklo_epi16(r0, r1);
+  __m512i d1 = _mm512_unpackhi_epi16(r0, r1);
+  r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
+  r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
+  d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
+  d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
+  return std::make_tuple(d0, d1);
+}
+#pragma GCC diagnostic pop
+
+#endif
+
 }  // anonymous namespace
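Both transposes above are pure data movement, so they are easy to unit-test against plain scalar oracles such as the following (illustrative helpers under the stated layouts, not part of the patch):

#include <cstdint>

// Oracle for transpose_16x16_32bit: a row-major 16x16 matrix of 32-bit words.
void transpose_16x16_ref(const uint32_t* src, uint32_t* dst) {
  for (int r = 0; r < 16; ++r)
    for (int c = 0; c < 16; ++c)
      dst[c * 16 + r] = src[r * 16 + c];
}

// Oracle for transpose_2x32_16bit: interleave two 32-element 16-bit rows so
// that d0 = {a0, b0, ..., a15, b15} and d1 = {a16, b16, ..., a31, b31}.
void interleave_2x32_ref(const uint16_t* a, const uint16_t* b, uint16_t* d0, uint16_t* d1) {
  for (int i = 0; i < 16; ++i) {
    d0[2 * i] = a[i];
    d0[2 * i + 1] = b[i];
    d1[2 * i] = a[16 + i];
    d1[2 * i + 1] = b[16 + i];
  }
}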
diff --git a/test/srt/cpu/test_mla.py b/test/srt/cpu/test_mla.py
new file mode 100644
index 000000000..217e33b71
--- /dev/null
+++ b/test/srt/cpu/test_mla.py
@@ -0,0 +1,155 @@
+import itertools
+import unittest
+
+import sgl_kernel
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+from utils import precision
+
+from sglang.test.test_utils import CustomTestCase
+
+
+class TestMLA(CustomTestCase):
+    def _run_sdpa_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        key: torch.Tensor,
+        loc: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        # set kv cache
+        k_cache[loc] = key
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                scaled_dot_product_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
+        dtype = torch.bfloat16
+
+        total_tokens = B * seq_len
+        sm_scale = 1.0 / (D**0.5)
+        logit_cap = 0.0
+        num_kv_splits = 8
+        enable_gqa = H_Q != H_KV
+
+        # q represents the new token being generated, one per batch
+        q = torch.randn(B, H_Q, D, dtype=dtype)
+
+        # k_buffer and v_buffer represent all previous tokens
+        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
+        v_buffer = k_buffer.narrow(2, 0, D_V)
+
+        key = torch.randn(B, H_KV, D, dtype=dtype)
+        value = key.narrow(2, 0, D_V)
+        # make sure no duplicates in loc
+        loc = torch.randperm(total_tokens)[:B].to(torch.int64)
+
+        k_buffer2 = k_buffer.clone()
+        v_buffer2 = k_buffer2.narrow(2, 0, D_V)
+
+        # o will have the same shape as q
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype)
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)
+
+        req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
+        b_req_idx = torch.arange(B).to(torch.int64)
+        b_seq_len = torch.full((B,), seq_len).to(torch.int64)
+
+        attn_logits = torch.empty(
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
+        )
+
+        torch.ops.sgl_kernel.decode_attention_cpu(
+            q,
+            k_buffer2,
+            v_buffer2,
+            o,
+            key,
+            value,
+            loc,
+            attn_logits,
+            req_to_token,
+            b_req_idx,
+            b_seq_len,
+            sm_scale,
+            logit_cap,
+        )
+
+        self._run_sdpa_forward_decode(
+            q,
+            o_grouped,
+            k_buffer,
+            v_buffer,
+            key,
+            loc,
+            req_to_token,
+            b_req_idx,
+            b_seq_len,
+            scaling=sm_scale,
+            enable_gqa=enable_gqa,
+        )
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            o.flatten(), o_grouped.flatten(), dim=0
+        )
+        atol = rtol = precision[q.dtype]
+        self.assertGreater(cos_sim.item(), 0.99)
+        torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
+        torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
+        torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)
+
+    def test_grouped_decode_attention(self):
+        configs = [
+            (1, 22, 1, 576, 512, 8 * 111),
+            (4, 22, 1, 576, 512, 8 * 128),
+            (40, 22, 1, 576, 512, 8 * 133),
+        ]
+
+        for B, H_Q, H_KV, D, D_V, seqlen in configs:
+            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/cpu/test_moe.py b/test/srt/cpu/test_moe.py
index 098f72cf1..62542e366 100644
--- a/test/srt/cpu/test_moe.py
+++ b/test/srt/cpu/test_moe.py
@@ -33,7 +33,7 @@ def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
     topk_weights = torch.empty(B, topk, dtype=torch.float32)
     topk_ids = torch.empty(B, topk, dtype=torch.int32)
     topk_weights, topk_ids = kernel.grouped_topk_cpu(
-        a, score, topk, renormalize, G, topk_group
+        a, score, topk, renormalize, G, topk_group, 0, None, None
     )
 
     packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
diff --git a/test/srt/cpu/test_norm.py b/test/srt/cpu/test_norm.py
index b7d139a5b..fa4530afd 100644
--- a/test/srt/cpu/test_norm.py
+++ b/test/srt/cpu/test_norm.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
 
 import sgl_kernel
 import torch
-from utils import precision
+from utils import make_non_contiguous, precision
 
 from sglang.test.test_utils import CustomTestCase
 
@@ -38,6 +38,7 @@ class TestNorm(CustomTestCase):
 
     def _norm_test(self, m, n, dtype):
         x = torch.randn([m, n], dtype=dtype)
+        x = make_non_contiguous(x)
         hidden_size = x.size(-1)
         weight = torch.randn(hidden_size, dtype=dtype)
         variance_epsilon = 1e-6
@@ -49,7 +50,7 @@ class TestNorm(CustomTestCase):
         self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
 
         ref_x = x.clone()
-        residual = torch.randn([m, n], dtype=dtype)
+        residual = torch.randn([m, hidden_size], dtype=dtype)
         ref_residual = residual.clone()
 
         torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
diff --git a/test/srt/cpu/test_qkv_proj_with_rope.py b/test/srt/cpu/test_qkv_proj_with_rope.py
index 0d2f7d940..9d4b80f6a 100644
--- a/test/srt/cpu/test_qkv_proj_with_rope.py
+++ b/test/srt/cpu/test_qkv_proj_with_rope.py
@@ -14,6 +14,7 @@ from sglang.test.test_utils import CustomTestCase
 
 convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
 qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
+qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
 torch.manual_seed(0)
 
 # constants
 kv_lora_rank = 512
@@ -148,6 +149,7 @@ class TestQKVProjWithROPE(CustomTestCase):
         kv_a_proj_weight = (
             torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
         )
+        fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
         norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
         pos = torch.randint(10, 100, (B,))
         cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
@@ -167,6 +169,7 @@ class TestQKVProjWithROPE(CustomTestCase):
         qb_packed = convert_weight_packed(q_b_proj_weight)
         kva_packed = convert_weight_packed(kv_a_proj_weight)
         wkc_packed = convert_weight_packed(w_kc)
+        fused_weight_packed = convert_weight_packed(fused_weight)
 
         q_out, k_out, v_out = qkv_proj_with_rope(
             hidden_states,
@@ -187,10 +190,33 @@ class TestQKVProjWithROPE(CustomTestCase):
             True,
             None,
         )
+        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
+            hidden_states,
+            fused_weight_packed,
+            qb_packed,
+            wkc_packed,
+            norm_weight1,
+            norm_weight2,
+            pos,
+            cos_sin_cache,
+            eps,
+            False,
+            False,
+            None,
+            None,
+            True,
+            None,
+            q_lora_rank,
+            kv_lora_rank,
+            qk_rope_head_dim,
+        )
 
         atol = rtol = precision[q_ref.dtype]
         self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
+        self.assertTrue(torch.allclose(fused_q_out, q_out))
+        self.assertTrue(torch.allclose(fused_k_out, k_out))
+        self.assertTrue(torch.allclose(fused_v_out, v_out))
 
     def test_int8_qkv_proj_with_rope(self):
         dtype = torch.bfloat16
@@ -252,10 +278,36 @@ class TestQKVProjWithROPE(CustomTestCase):
             True,
             None,
         )
+        fused_weight = torch.cat([w1_q, w3_q], dim=0)
+        fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
+        w_fused_q_packed = convert_weight_packed(fused_weight)
+        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
+            hidden_states,
+            w_fused_q_packed,
+            w2_q_packed,
+            wkc_packed,
+            norm_weight1,
+            norm_weight2,
+            pos,
+            cos_sin_cache,
+            eps,
+            True,
+            False,
+            fused_weight_s,
+            w2_s,
+            True,
+            None,
+            q_lora_rank,
+            kv_lora_rank,
+            qk_rope_head_dim,
+        )
 
         atol = rtol = precision[q_ref.dtype]
         self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
+        self.assertTrue(torch.allclose(fused_q_out, q_out))
+        self.assertTrue(torch.allclose(fused_k_out, k_out))
+        self.assertTrue(torch.allclose(fused_v_out, v_out))
 
     def test_fp8_qkv_proj_with_rope(self):
         dtype = torch.bfloat16
@@ -311,17 +363,17 @@ class TestQKVProjWithROPE(CustomTestCase):
             pos,
             cos_sin_cache,
         )
-        fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight)
-        fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight)
-        fp8_kv_a_proj_with_mqa_weight = convert_weight_packed(
+        fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
+        fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
+        fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
             fp8_kv_a_proj_with_mqa_weight
         )
         w_kc = convert_weight_packed(w_kc)
 
         q_out, k_out, v_out = qkv_proj_with_rope(
             hidden_states,
-            fp8_q_a_proj_weight,
-            fp8_q_b_proj_weight,
-            fp8_kv_a_proj_with_mqa_weight,
+            fp8_q_a_proj_weight_packed,
+            fp8_q_b_proj_weight_packed,
+            fp8_kv_a_proj_with_mqa_weight_packed,
             w_kc,
             norm_weight1,
             norm_weight2,
@@ -336,10 +388,44 @@ class TestQKVProjWithROPE(CustomTestCase):
             True,
             [scale_block_size_N, scale_block_size_K],
         )
+
+        fused_weight = torch.cat(
+            [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
+        )
+        fused_weight_s = torch.cat(
+            [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
+        )
+        fused_weight_packed = convert_weight_packed(fused_weight)
+        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
+            hidden_states,
+            fused_weight_packed,
+            fp8_q_b_proj_weight_packed,
+            w_kc,
+            norm_weight1,
+            norm_weight2,
+            pos,
+            cos_sin_cache,
+            eps,
+            False,
+            True,
+            fused_weight_s.float(),
+            q_b_proj_weight_scale_inv.float(),
+            True,
+            [scale_block_size_N, scale_block_size_K],
+            q_lora_rank,
+            kv_lora_rank,
+            qk_rope_head_dim,
+        )
 
         atol = rtol = precision[q_ref.dtype]
-        self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
-        self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
-        self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
+        # Due to the change in multiplication order, the numerical error in
+        # q_out is amplified. A model with few layers is unaffected, but tests
+        # covering more layers need an enlarged tolerance on q_out to pass.
+        torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
+        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
+        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
+        torch.testing.assert_close(fused_q_out, q_out)
+        torch.testing.assert_close(fused_k_out, k_out)
+        torch.testing.assert_close(fused_v_out, v_out)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/cpu/test_topk.py b/test/srt/cpu/test_topk.py
index 3d0138d9a..420f6cbb7 100644
--- a/test/srt/cpu/test_topk.py
+++ b/test/srt/cpu/test_topk.py
@@ -34,7 +34,15 @@ class TestGroupedTopK(CustomTestCase):
 
         # fused version
         topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
-            hidden_states, gating_output, topk, renormalize, G, topk_group
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            G,
+            topk_group,
+            0,
+            None,
+            None,
         )
 
         res = torch.zeros(M, E, dtype=torch.float)
@@ -83,6 +91,9 @@ class TestBiasedGroupedTopK(CustomTestCase):
             renormalize,
             G,
             topk_group,
+            0,
+            None,
+            None,
         )
 
         res = torch.zeros(M, E, dtype=torch.float)
diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py
index 1716782fe..3a4e44aa1 100644
--- a/test/srt/cpu/utils.py
+++ b/test/srt/cpu/utils.py
@@ -244,3 +244,11 @@ def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
         .sum(dim=1)
         .to(a.dtype)
     )
+
+
+def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
+    """
+    Make a tensor non-contiguous by slicing it along the last dimension.
+    """
+    last_dim = x.shape[-1]
+    return x[..., : last_dim // 2] if x.is_contiguous() else x