diff --git a/sgl-kernel/csrc/cpu/qkv_proj.cpp b/sgl-kernel/csrc/cpu/qkv_proj.cpp index 1a5361941..82c4d6583 100644 --- a/sgl-kernel/csrc/cpu/qkv_proj.cpp +++ b/sgl-kernel/csrc/cpu/qkv_proj.cpp @@ -152,6 +152,85 @@ void segment_gemm_kernel_impl( }); } +// [C0, C1] = A @ [B0, B1] +template <typename scalar_t> +void segment_gemm_kernel_impl( + scalar_t* __restrict__ C0, + scalar_t* __restrict__ C1, + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B0, + const at::Float8_e4m3fn* __restrict__ B1, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N0, + int64_t N1, + int64_t K, + int64_t block_size_N, + int64_t block_size_K) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB0 = div_up(N0, BLOCK_N); + const int64_t NB1 = div_up(N1, BLOCK_N); + const int64_t NB = NB0 + NB1; + + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + // parallel on [MB, NB0 + NB1] + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + // for brgemm when mat2 is float8_e4m3 + alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K]; + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1; + const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1; + scalar_t* __restrict__ C = nb < NB0 ? C0 : C1; + int64_t ldc = nb < NB0 ? N0 : N1; + int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0; + int64_t new_nb = nb < NB0 ? 
nb : nb - NB0; + + tinygemm_kernel( + /* A */ A + mb_start * K, + /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */, + /* C */ C + mb_start * ldc + local_nb_start, + /* Btmp*/ Btmp, + /* Ctmp*/ Ctmp, + /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ ldc, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); +} + template <typename scalar_t> inline float reduce(const scalar_t* __restrict__ x, int64_t size) { using bVec = at::vec::Vectorized<scalar_t>; @@ -321,6 +400,15 @@ extern at::Tensor int8_scaled_mm_with_quant( extern void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale); +extern at::Tensor fp8_scaled_mm_cpu( + at::Tensor& mat1, + at::Tensor& mat2, + at::Tensor& scales2, + std::vector<int64_t> block_size, + const std::optional<at::Tensor>& bias, + at::ScalarType out_dtype, + bool is_vnni); + // NB: shapes in DeepDeek R1 // // hidden_states : [num_seqs, hidden_size] [1, 7168] @@ -343,10 +431,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( at::Tensor& cos_sin_cache, double eps, bool use_int8_w8a8, + bool use_fp8_w8a16, std::optional<at::Tensor> q_a_proj_scale, std::optional<at::Tensor> q_b_proj_scale, std::optional<at::Tensor> kv_a_proj_scale, - bool is_vnni) { + bool is_vnni, + std::optional<std::vector<int64_t>> block_size) { RECORD_FUNCTION( "sgl-kernel::qkv_proj_with_rope", std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc})); @@ -394,7 +484,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8."); TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8."); } - + if (use_fp8_w8a16) { + TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16."); + TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 
w8a16."); + TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16."); + } // outputs and temp buffer const auto options = hidden_states.options(); auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options); @@ -436,6 +532,29 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( q_lora_rank, kv_lora_rank + qk_rope_head_dim, hidden_size); + } else if (use_fp8_w8a16) { + int64_t block_size_N = block_size.value()[0]; + int64_t block_size_K = block_size.value()[1]; + auto q_a_proj_s = q_a_proj_scale.value(); + auto kv_a_proj_s = kv_a_proj_scale.value(); + CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N)); + CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K)); + CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N)); + CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K)); + segment_gemm_kernel_impl( + qa.data_ptr<at::BFloat16>(), + k_input.data_ptr<at::BFloat16>(), + hidden_states.data_ptr<at::BFloat16>(), + q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(), + kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(), + q_a_proj_s.data_ptr<float>(), + kv_a_proj_s.data_ptr<float>(), + num_seqs, + q_lora_rank, + kv_lora_rank + qk_rope_head_dim, + hidden_size, + block_size_N, + block_size_K); } else { segment_gemm_kernel_impl( qa.data_ptr<at::BFloat16>(), @@ -469,6 +588,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( std::optional<at::Tensor> bias; if (use_int8_w8a8) { qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni); + } else if (use_fp8_w8a16) { + qb = fp8_scaled_mm_cpu( + qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni); } else { qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni); } diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index bfc367606..f8e9a4559 100644 --- 
a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( at::Tensor& cos_sin_cache, double eps, bool use_int8_w8a8, + bool use_fp8_w8a16, std::optional<at::Tensor> q_a_proj_scale, std::optional<at::Tensor> q_b_proj_scale, std::optional<at::Tensor> kv_a_proj_scale, - bool is_vnni); + bool is_vnni, + std::optional<std::vector<int64_t>> block_size); // shared memory init void initialize(int64_t size, int64_t rank); @@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // decode m.def( - "decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor " - "req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()"); + "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, " + "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, " + "float logit_cap) -> ()"); m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu); // extend @@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor " "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, " - "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? " - "kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)"); + "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? " + "q_b_proj_scale, Tensor? " + "kv_a_proj_scale, bool is_vnni, int[]? 
block_size) -> (Tensor, Tensor, Tensor)"); m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope); // shared expert diff --git a/test/srt/cpu/test_decode.py b/test/srt/cpu/test_decode.py index 1ab1bfae8..7e15a58aa 100644 --- a/test/srt/cpu/test_decode.py +++ b/test/srt/cpu/test_decode.py @@ -1,7 +1,7 @@ import unittest +import sgl_kernel import torch -from sgl_kernel.common_ops import decode_attention_cpu as decode_attention from torch.nn.functional import scaled_dot_product_attention from sglang.test.test_utils import CustomTestCase @@ -105,7 +105,7 @@ class TestDecodeAttention(CustomTestCase): v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) key = key.transpose(0, 1).contiguous().transpose(0, 1) value = value.transpose(0, 1).contiguous().transpose(0, 1) - decode_attention( + torch.ops.sgl_kernel.decode_attention_cpu( q, k_buffer, v_buffer, diff --git a/test/srt/cpu/test_extend.py b/test/srt/cpu/test_extend.py index 35fbfc184..c119c1524 100644 --- a/test/srt/cpu/test_extend.py +++ b/test/srt/cpu/test_extend.py @@ -1,7 +1,7 @@ import unittest +import sgl_kernel import torch -from sgl_kernel.common_ops import extend_attention_cpu as extend_attention from torch.nn.functional import scaled_dot_product_attention from sglang.test.test_utils import CustomTestCase @@ -157,7 +157,7 @@ class TestExtendAttention(CustomTestCase): ) o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) - extend_attention( + torch.ops.sgl_kernel.extend_attention_cpu( q_extend, k_extend, v_extend, diff --git a/test/srt/cpu/test_qkv_proj_with_rope.py b/test/srt/cpu/test_qkv_proj_with_rope.py new file mode 100644 index 000000000..0d2f7d940 --- /dev/null +++ b/test/srt/cpu/test_qkv_proj_with_rope.py @@ -0,0 +1,346 @@ +import unittest + +import sgl_kernel +import torch +from utils import ( + convert_weight, + native_w8a8_per_token_matmul, + per_token_quant_int8, + precision, +) + +from sglang.srt.layers.rotary_embedding import _apply_rotary_emb +from 
sglang.test.test_utils import CustomTestCase + +convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed +qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope +torch.manual_seed(0) +# constants +kv_lora_rank = 512 +qk_head_dim = 192 +qk_nope_head_dim = 128 +qk_rope_head_dim = 64 +rotary_dim = qk_rope_head_dim +num_heads = 22 +q_lora_rank = 1536 +hidden_size = 7168 +B = 1 +eps = 1e-6 + + +def layernorm(x, weight, variance_epsilon=1e-6, residual=None): + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + return (x * weight).to(orig_dtype) + + +def rotary_emb(q_pe, k_pe, pos, cos_sin_cache): + orig_dtype = q_pe.dtype + q_pe = q_pe.float() + k_pe = k_pe.float() + cos_sin_cache = cos_sin_cache.float() + + query_rot = q_pe[..., :rotary_dim] + key_rot = k_pe[..., :rotary_dim] + cos_sin = cos_sin_cache[pos] + cos, sin = cos_sin.chunk(2, dim=-1) + query_rot = _apply_rotary_emb(query_rot, cos, sin, False) + key_rot = _apply_rotary_emb(key_rot, cos, sin, False) + return query_rot.to(orig_dtype), key_rot.to(orig_dtype) + + +def native_torch( + q_input, + hidden_states, + q_a_proj_weight, + norm_weight1, + q_b_proj_weight, + w_kc, + kv_a_proj_weight, + norm_weight2, + pos, + cos_sin_cache, +): + + q = torch.matmul(hidden_states, q_a_proj_weight.t()) + q = layernorm(q, norm_weight1) + q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim) + + q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc) + + q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1) + latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t()) + v_input = latent_cache[..., :kv_lora_rank] + v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1) + k_input = latent_cache.unsqueeze(1) + k_input[..., :kv_lora_rank] = v_input + k_pe = k_input[..., kv_lora_rank:] + + q_pe, k_pe = 
rotary_emb(q_pe, k_pe, pos, cos_sin_cache) + q_input[..., kv_lora_rank:] = q_pe + k_input[..., kv_lora_rank:] = k_pe + + return q_input, k_input, v_input + + +def native_torch_int8( + q_input, + hidden_states, + w1_q, + w1_s, + norm_weight1, + w2_q, + w2_s, + w_kc, + w3_q, + w3_s, + norm_weight2, + pos, + cos_sin_cache, +): + + a_q, a_s = per_token_quant_int8(hidden_states) + q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16) + q = layernorm(q, norm_weight1) + + a_q, a_s = per_token_quant_int8(q) + q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view( + -1, num_heads, qk_head_dim + ) + + q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc) + + q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1) + a_q, a_s = per_token_quant_int8(hidden_states) + latent_cache = native_w8a8_per_token_matmul( + a_q, w3_q, a_s, w3_s, None, torch.bfloat16 + ) + v_input = latent_cache[..., :kv_lora_rank] + v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1) + k_input = latent_cache.unsqueeze(1) + k_input[..., :kv_lora_rank] = v_input + k_pe = k_input[..., kv_lora_rank:] + + q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache) + q_input[..., kv_lora_rank:] = q_pe + k_input[..., kv_lora_rank:] = k_pe + + return q_input, k_input, v_input + + +class TestQKVProjWithROPE(CustomTestCase): + def test_bf16_qkv_proj_with_rope(self): + dtype = torch.bfloat16 + hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size + q_input = torch.empty( + B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype + ) + q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 + norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) + q_b_proj_weight = ( + torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 + ) + w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 + kv_a_proj_weight = ( + 
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 + ) + norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) + pos = torch.randint(10, 100, (B,)) + cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) + q_ref, k_ref, v_ref = native_torch( + q_input, + hidden_states, + q_a_proj_weight, + norm_weight1, + q_b_proj_weight, + w_kc.transpose(1, 2), + kv_a_proj_weight, + norm_weight2, + pos, + cos_sin_cache, + ) + qa_packed = convert_weight_packed(q_a_proj_weight) + qb_packed = convert_weight_packed(q_b_proj_weight) + kva_packed = convert_weight_packed(kv_a_proj_weight) + wkc_packed = convert_weight_packed(w_kc) + + q_out, k_out, v_out = qkv_proj_with_rope( + hidden_states, + qa_packed, + qb_packed, + kva_packed, + wkc_packed, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + False, + False, + None, + None, + None, + True, + None, + ) + atol = rtol = precision[q_ref.dtype] + self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) + + def test_int8_qkv_proj_with_rope(self): + dtype = torch.bfloat16 + hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size + q_input = torch.empty( + B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype + ) + q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 + norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) + q_b_proj_weight = ( + torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 + ) + w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 + kv_a_proj_weight = ( + torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 + ) + norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) + pos = torch.randint(10, 100, (B,)) + cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) + + w1_q, w1_s = per_token_quant_int8(q_a_proj_weight) + 
w2_q, w2_s = per_token_quant_int8(q_b_proj_weight) + w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight) + q_ref, k_ref, v_ref = native_torch_int8( + q_input, + hidden_states, + w1_q, + w1_s, + norm_weight1, + w2_q, + w2_s, + w_kc.transpose(1, 2), + w3_q, + w3_s, + norm_weight2, + pos, + cos_sin_cache, + ) + w1_q_packed = convert_weight_packed(w1_q) + w2_q_packed = convert_weight_packed(w2_q) + w3_q_packed = convert_weight_packed(w3_q) + wkc_packed = convert_weight_packed(w_kc) + q_out, k_out, v_out = qkv_proj_with_rope( + hidden_states, + w1_q_packed, + w2_q_packed, + w3_q_packed, + wkc_packed, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + True, + False, + w1_s, + w2_s, + w3_s, + True, + None, + ) + atol = rtol = precision[q_ref.dtype] + self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) + + def test_fp8_qkv_proj_with_rope(self): + dtype = torch.bfloat16 + hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size + q_input = torch.empty( + B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype + ) + q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1 + norm_weight1 = torch.randn(q_lora_rank, dtype=dtype) + q_b_proj_weight = ( + torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1 + ) + w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1 + kv_a_proj_weight = ( + torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 + ) + norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) + pos = torch.randint(10, 100, (B,)) + cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) + + scale_block_size_N = 128 + scale_block_size_K = 128 + fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = ( + convert_weight( + q_a_proj_weight, + [scale_block_size_N, scale_block_size_K], + 
torch.bfloat16, + ) + ) + fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = ( + convert_weight( + q_b_proj_weight, + [scale_block_size_N, scale_block_size_K], + torch.bfloat16, + ) + ) + ( + fp8_kv_a_proj_with_mqa_weight, + kv_a_proj_with_mqa_weight_scale_inv, + kv_a_proj_with_mqa_weight_dq, + ) = convert_weight( + kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16 + ) + q_ref, k_ref, v_ref = native_torch( + q_input, + hidden_states, + q_a_proj_weight_dq, + norm_weight1, + q_b_proj_weight_dq, + w_kc.transpose(1, 2), + kv_a_proj_with_mqa_weight_dq, + norm_weight2, + pos, + cos_sin_cache, + ) + fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight) + fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight) + fp8_kv_a_proj_with_mqa_weight = convert_weight_packed( + fp8_kv_a_proj_with_mqa_weight + ) + w_kc = convert_weight_packed(w_kc) + q_out, k_out, v_out = qkv_proj_with_rope( + hidden_states, + fp8_q_a_proj_weight, + fp8_q_b_proj_weight, + fp8_kv_a_proj_with_mqa_weight, + w_kc, + norm_weight1, + norm_weight2, + pos, + cos_sin_cache, + eps, + False, + True, + q_a_proj_weight_scale_inv.float(), + q_b_proj_weight_scale_inv.float(), + kv_a_proj_with_mqa_weight_scale_inv.float(), + True, + [scale_block_size_N, scale_block_size_K], + ) + atol = rtol = precision[q_ref.dtype] + self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) + + +if __name__ == "__main__": + unittest.main()