From 13fb8b5489447c7e70a3e9b13793bc5bceda9c49 Mon Sep 17 00:00:00 2001 From: blzheng Date: Thu, 23 Oct 2025 12:39:51 +0800 Subject: [PATCH] [CPU] Optimize FP16 decode_attention_cpu (#10652) --- .../srt/layers/vocab_parallel_embedding.py | 5 +- sgl-kernel/csrc/cpu/decode.cpp | 173 +++++++++++++++++- sgl-kernel/csrc/cpu/vec.h | 2 +- test/srt/cpu/test_decode.py | 10 +- 4 files changed, 181 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 66abb7541..986babf2b 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding): # We only support pack LMHead if it's not quantized. if _is_cpu and _is_cpu_amx_available: - if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16: + if hasattr(self, "weight") and self.weight.dtype in [ + torch.bfloat16, + torch.float16, + ]: self.quant_method = PackWeightMethod(weight_names=["weight"]) if bias: diff --git a/sgl-kernel/csrc/cpu/decode.cpp b/sgl-kernel/csrc/cpu/decode.cpp index ae5ac51c8..3de2708e7 100644 --- a/sgl-kernel/csrc/cpu/decode.cpp +++ b/sgl-kernel/csrc/cpu/decode.cpp @@ -308,6 +308,93 @@ struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> { }; #endif +#if defined(CPU_CAPABILITY_AVX512) +template <typename index_t, int BLOCK_M, int BLOCK_N> +struct tinygemm_kernel_nt<at::Half, index_t, BLOCK_M, BLOCK_N> { + static inline void apply( + const at::Half* __restrict__ A, + const at::Half* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + float scale, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t K, + int64_t max_tokens) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N; + + __m512 va0, va1; + __m512 vb0[COLS], vb1[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vscale = _mm512_set1_ps(scale); + + auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); }; + Unroll<ROWS * COLS>{}(loadc); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / 
COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + __m512i a16 = _mm512_loadu_si512((__m512i const*)(A + row * lda + k)); + va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0)); + va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1)); + } + + if constexpr (row == 0) { + int64_t b_idx = indices[col]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + __m512i b16 = _mm512_loadu_si512((__m512i const*)(B + b_idx * ldb + k)); + vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0)); + vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1)); + } + + vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i])); + }; + + auto compute2 = [&](auto i, int64_t k, __mmask32 mask) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + __m512i a16 = _mm512_maskz_loadu_epi16(mask, (const void*)(A + row * lda + k)); + va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0)); + va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1)); + } + + if constexpr (row == 0) { + int64_t b_idx = indices[col]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + __m512i b16 = _mm512_maskz_loadu_epi16(mask, (const void*)(B + b_idx * ldb + k)); + vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0)); + vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1)); + } + + vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i])); + }; + + int64_t k = 0; + for (; k <= K - 32; k += 32) { + Unroll<ROWS * COLS>{}(compute, k); + } + int64_t count = K - k; + if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + Unroll<ROWS * COLS>{}(compute2, k, mask); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale)); + }; + Unroll<ROWS * COLS>{}(storec); + } +}; +#endif + #define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \ tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \ A + mb_start * 
lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens); @@ -443,6 +530,87 @@ struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> { }; #endif +#if defined(CPU_CAPABILITY_AVX512) +template <typename index_t, int BLOCK_M, int BLOCK_N> +struct tinygemm_kernel_nn<at::Half, index_t, BLOCK_M, BLOCK_N> { + static inline void apply( + const float* __restrict__ A, + const at::Half* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + const float* __restrict__ scale, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t K, + int64_t max_tokens) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vscale; + + auto loadc = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" + if constexpr (col == 0) { + vscale = _mm512_set1_ps(scale[row]); + } +#pragma GCC diagnostic pop + vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16); + vc[i] = _mm512_mul_ps(vc[i], vscale); + }; + Unroll<ROWS * COLS>{}(loadc); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_ps(A[row * lda + k]); + } + if constexpr (row == 0) { + if (k + 1 < K) { + int64_t b_idx_prefetch = indices[k + 1]; + _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0); + } + int64_t b_idx = indices[k]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + + // for COLS = 2, 4, 6, 8 use 512 bit load + // for COLS = 1, 3, 5, 7 use 256 bit load + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 == 0) { + __m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16)); + vb[col + 0] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0)); + vb[col + 1] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1)); + } + } else { + __m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16)); + vb[col] = CVT_FP16_TO_FP32(b16); + } + } + vc[i] 
= _mm512_fmadd_ps(va, vb[col], vc[i]); + }; + + for (int64_t k = 0; k < K; ++k) { + Unroll<ROWS * COLS>{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]); + }; + Unroll<ROWS * COLS>{}(storec); + } +}; +#endif + #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \ A + mb_start * lda, \ @@ -512,9 +680,10 @@ void index_gemm_kernel_nt( return; } - // pattern: 1-6-24 + // default pattern: 1-6-24 + // FP16 pattern: 2-8-16 constexpr int64_t BLOCK_M = 4; - constexpr int64_t BLOCK_N = 6; + constexpr int64_t BLOCK_N = std::is_same_v<scalar_t, at::Half> ? 4 : 6; const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h index d28124c1d..d12cd8c6e 100644 --- a/sgl-kernel/csrc/cpu/vec.h +++ b/sgl-kernel/csrc/cpu/vec.h @@ -47,7 +47,7 @@ convert_from_float_ext(const Vectorized<float>& a, const Vectorize #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) -#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) +#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a) // this doesn't handle NaN. 
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { diff --git a/test/srt/cpu/test_decode.py b/test/srt/cpu/test_decode.py index c77378e1a..a7c5dd755 100644 --- a/test/srt/cpu/test_decode.py +++ b/test/srt/cpu/test_decode.py @@ -59,8 +59,7 @@ class TestDecodeAttention(CustomTestCase): return output - def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device): - dtype = torch.bfloat16 + def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device): # This represents the number of tokens already in the sequence seq_len = 1024 total_tokens = B * seq_len @@ -158,9 +157,10 @@ class TestDecodeAttention(CustomTestCase): ] for B, H_Q, H_KV, D, D_V in configs: - self._test_grouped_decode_attention_once( - B, H_Q, H_KV, D, D_V, device=device - ) + for dtype in [torch.bfloat16, torch.float16]: + self._test_grouped_decode_attention_once( + B, H_Q, H_KV, D, D_V, dtype=dtype, device=device + ) def test_grouped_decode_attention(self): self._test_grouped_decode_attention("cpu")