Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT (#6493)
This commit is contained in:
@@ -152,6 +152,85 @@ void segment_gemm_kernel_impl(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// [C0, C1] = A @ [B0, B1]
|
||||||
|
template <typename scalar_t>
|
||||||
|
void segment_gemm_kernel_impl(
|
||||||
|
scalar_t* __restrict__ C0,
|
||||||
|
scalar_t* __restrict__ C1,
|
||||||
|
const scalar_t* __restrict__ A,
|
||||||
|
const at::Float8_e4m3fn* __restrict__ B0,
|
||||||
|
const at::Float8_e4m3fn* __restrict__ B1,
|
||||||
|
const float* __restrict__ Bs0,
|
||||||
|
const float* __restrict__ Bs1,
|
||||||
|
int64_t M,
|
||||||
|
int64_t N0,
|
||||||
|
int64_t N1,
|
||||||
|
int64_t K,
|
||||||
|
int64_t block_size_N,
|
||||||
|
int64_t block_size_K) {
|
||||||
|
constexpr int64_t BLOCK_M = block_size_m();
|
||||||
|
constexpr int64_t BLOCK_N = block_size_n();
|
||||||
|
const int64_t MB = div_up(M, BLOCK_M);
|
||||||
|
const int64_t NB0 = div_up(N0, BLOCK_N);
|
||||||
|
const int64_t NB1 = div_up(N1, BLOCK_N);
|
||||||
|
const int64_t NB = NB0 + NB1;
|
||||||
|
|
||||||
|
const int64_t scale_size_K = div_up(K, block_size_K);
|
||||||
|
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||||
|
|
||||||
|
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||||
|
|
||||||
|
// parallel on [MB, NB0 + NB1]
|
||||||
|
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
int64_t mb{0}, nb{0};
|
||||||
|
data_index_init(begin, mb, MB, nb, NB);
|
||||||
|
|
||||||
|
// for brgemm, use float32 for accumulate
|
||||||
|
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||||
|
// for brgemm when mat2 is float8_e4m3
|
||||||
|
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
|
||||||
|
|
||||||
|
for (int64_t i = begin; i < end; ++i) {
|
||||||
|
UNUSED(i);
|
||||||
|
|
||||||
|
int mb_start = mb * BLOCK_M;
|
||||||
|
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||||
|
int nb_start = nb * BLOCK_N;
|
||||||
|
int nb_size = BLOCK_N;
|
||||||
|
|
||||||
|
const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
|
||||||
|
const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
|
||||||
|
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
|
||||||
|
int64_t ldc = nb < NB0 ? N0 : N1;
|
||||||
|
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
|
||||||
|
int64_t new_nb = nb < NB0 ? nb : nb - NB0;
|
||||||
|
|
||||||
|
tinygemm_kernel<scalar_t>(
|
||||||
|
/* A */ A + mb_start * K,
|
||||||
|
/* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
|
||||||
|
/* C */ C + mb_start * ldc + local_nb_start,
|
||||||
|
/* Btmp*/ Btmp,
|
||||||
|
/* Ctmp*/ Ctmp,
|
||||||
|
/* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
|
||||||
|
/* M */ mb_size,
|
||||||
|
/* N */ nb_size,
|
||||||
|
/* K */ K,
|
||||||
|
/* lda */ K,
|
||||||
|
/* ldb */ nb_size,
|
||||||
|
/* ldc */ ldc,
|
||||||
|
/* brg */ use_brgemm,
|
||||||
|
/* block_size_K */ block_size_K);
|
||||||
|
|
||||||
|
// move to the next index
|
||||||
|
data_index_step(mb, MB, nb, NB);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_brgemm) {
|
||||||
|
at::native::cpublas::brgemm_release();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
|
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
|
||||||
using bVec = at::vec::Vectorized<scalar_t>;
|
using bVec = at::vec::Vectorized<scalar_t>;
|
||||||
@@ -321,6 +400,15 @@ extern at::Tensor int8_scaled_mm_with_quant(
|
|||||||
extern void
|
extern void
|
||||||
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
|
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
|
||||||
|
|
||||||
|
extern at::Tensor fp8_scaled_mm_cpu(
|
||||||
|
at::Tensor& mat1,
|
||||||
|
at::Tensor& mat2,
|
||||||
|
at::Tensor& scales2,
|
||||||
|
std::vector<int64_t> block_size,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
at::ScalarType out_dtype,
|
||||||
|
bool is_vnni);
|
||||||
|
|
||||||
// NB: shapes in DeepSeek R1
|
// NB: shapes in DeepSeek R1
|
||||||
//
|
//
|
||||||
// hidden_states : [num_seqs, hidden_size] [1, 7168]
|
// hidden_states : [num_seqs, hidden_size] [1, 7168]
|
||||||
@@ -343,10 +431,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
at::Tensor& cos_sin_cache,
|
at::Tensor& cos_sin_cache,
|
||||||
double eps,
|
double eps,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
|
bool use_fp8_w8a16,
|
||||||
std::optional<at::Tensor> q_a_proj_scale,
|
std::optional<at::Tensor> q_a_proj_scale,
|
||||||
std::optional<at::Tensor> q_b_proj_scale,
|
std::optional<at::Tensor> q_b_proj_scale,
|
||||||
std::optional<at::Tensor> kv_a_proj_scale,
|
std::optional<at::Tensor> kv_a_proj_scale,
|
||||||
bool is_vnni) {
|
bool is_vnni,
|
||||||
|
std::optional<std::vector<int64_t>> block_size) {
|
||||||
RECORD_FUNCTION(
|
RECORD_FUNCTION(
|
||||||
"sgl-kernel::qkv_proj_with_rope",
|
"sgl-kernel::qkv_proj_with_rope",
|
||||||
std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
|
std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
|
||||||
@@ -394,7 +484,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
|
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
|
||||||
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
|
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
|
||||||
}
|
}
|
||||||
|
if (use_fp8_w8a16) {
|
||||||
|
TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
|
||||||
|
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
|
||||||
|
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
|
||||||
|
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
|
||||||
|
TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
|
||||||
|
}
|
||||||
// outputs and temp buffer
|
// outputs and temp buffer
|
||||||
const auto options = hidden_states.options();
|
const auto options = hidden_states.options();
|
||||||
auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
|
auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
|
||||||
@@ -436,6 +532,29 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
q_lora_rank,
|
q_lora_rank,
|
||||||
kv_lora_rank + qk_rope_head_dim,
|
kv_lora_rank + qk_rope_head_dim,
|
||||||
hidden_size);
|
hidden_size);
|
||||||
|
} else if (use_fp8_w8a16) {
|
||||||
|
int64_t block_size_N = block_size.value()[0];
|
||||||
|
int64_t block_size_K = block_size.value()[1];
|
||||||
|
auto q_a_proj_s = q_a_proj_scale.value();
|
||||||
|
auto kv_a_proj_s = kv_a_proj_scale.value();
|
||||||
|
CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
|
||||||
|
CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
|
||||||
|
CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
|
||||||
|
CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
|
||||||
|
segment_gemm_kernel_impl<scalar_t>(
|
||||||
|
qa.data_ptr<scalar_t>(),
|
||||||
|
k_input.data_ptr<scalar_t>(),
|
||||||
|
hidden_states.data_ptr<scalar_t>(),
|
||||||
|
q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
|
||||||
|
kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
|
||||||
|
q_a_proj_s.data_ptr<float>(),
|
||||||
|
kv_a_proj_s.data_ptr<float>(),
|
||||||
|
num_seqs,
|
||||||
|
q_lora_rank,
|
||||||
|
kv_lora_rank + qk_rope_head_dim,
|
||||||
|
hidden_size,
|
||||||
|
block_size_N,
|
||||||
|
block_size_K);
|
||||||
} else {
|
} else {
|
||||||
segment_gemm_kernel_impl<scalar_t>(
|
segment_gemm_kernel_impl<scalar_t>(
|
||||||
qa.data_ptr<scalar_t>(),
|
qa.data_ptr<scalar_t>(),
|
||||||
@@ -469,6 +588,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
std::optional<at::Tensor> bias;
|
std::optional<at::Tensor> bias;
|
||||||
if (use_int8_w8a8) {
|
if (use_int8_w8a8) {
|
||||||
qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
|
qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
|
||||||
|
} else if (use_fp8_w8a16) {
|
||||||
|
qb = fp8_scaled_mm_cpu(
|
||||||
|
qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
|
||||||
} else {
|
} else {
|
||||||
qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
|
qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
at::Tensor& cos_sin_cache,
|
at::Tensor& cos_sin_cache,
|
||||||
double eps,
|
double eps,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
|
bool use_fp8_w8a16,
|
||||||
std::optional<at::Tensor> q_a_proj_scale,
|
std::optional<at::Tensor> q_a_proj_scale,
|
||||||
std::optional<at::Tensor> q_b_proj_scale,
|
std::optional<at::Tensor> q_b_proj_scale,
|
||||||
std::optional<at::Tensor> kv_a_proj_scale,
|
std::optional<at::Tensor> kv_a_proj_scale,
|
||||||
bool is_vnni);
|
bool is_vnni,
|
||||||
|
std::optional<std::vector<int64_t>> block_size);
|
||||||
|
|
||||||
// shared memory init
|
// shared memory init
|
||||||
void initialize(int64_t size, int64_t rank);
|
void initialize(int64_t size, int64_t rank);
|
||||||
@@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
|
|
||||||
// decode
|
// decode
|
||||||
m.def(
|
m.def(
|
||||||
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
|
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
|
||||||
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
|
"Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
|
||||||
|
"float logit_cap) -> ()");
|
||||||
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
|
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
|
||||||
|
|
||||||
// extend
|
// extend
|
||||||
@@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
|
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
|
||||||
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
|
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
|
||||||
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
|
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
|
||||||
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
|
"q_b_proj_scale, Tensor? "
|
||||||
|
"kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
|
||||||
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
|
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
|
||||||
|
|
||||||
// shared expert
|
// shared expert
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel.common_ops import decode_attention_cpu as decode_attention
|
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
@@ -105,7 +105,7 @@ class TestDecodeAttention(CustomTestCase):
|
|||||||
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
decode_attention(
|
torch.ops.sgl_kernel.decode_attention_cpu(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel.common_ops import extend_attention_cpu as extend_attention
|
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
@@ -157,7 +157,7 @@ class TestExtendAttention(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
|
o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
|
||||||
extend_attention(
|
torch.ops.sgl_kernel.extend_attention_cpu(
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
k_extend,
|
||||||
v_extend,
|
v_extend,
|
||||||
|
|||||||
346
test/srt/cpu/test_qkv_proj_with_rope.py
Normal file
346
test/srt/cpu/test_qkv_proj_with_rope.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
from utils import (
|
||||||
|
convert_weight,
|
||||||
|
native_w8a8_per_token_matmul,
|
||||||
|
per_token_quant_int8,
|
||||||
|
precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
|
||||||
|
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
|
||||||
|
torch.manual_seed(0)
|
||||||
|
# constants
|
||||||
|
kv_lora_rank = 512
|
||||||
|
qk_head_dim = 192
|
||||||
|
qk_nope_head_dim = 128
|
||||||
|
qk_rope_head_dim = 64
|
||||||
|
rotary_dim = qk_rope_head_dim
|
||||||
|
num_heads = 22
|
||||||
|
q_lora_rank = 1536
|
||||||
|
hidden_size = 7168
|
||||||
|
B = 1
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
    """Reference RMS-style normalisation over the last dimension.

    The input is upcast to float32, scaled by the reciprocal square root of
    ``mean(x ** 2) + variance_epsilon`` (mean taken over the last dim),
    multiplied by ``weight`` and cast back to the dtype ``x`` came in with.

    ``residual`` is accepted but not used by this implementation.
    """
    input_dtype = x.dtype
    x_fp32 = x.to(torch.float32)
    mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    normalised = x_fp32 * torch.rsqrt(mean_square + variance_epsilon)
    return (normalised * weight).to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    """Reference rotary embedding for the query and key rope slices.

    Both inputs and the cache are upcast to float32. The cache rows selected
    by ``pos`` are split in half along the last dim into cos and sin, and
    ``_apply_rotary_emb`` is applied (with its last argument ``False``) to the
    first ``rotary_dim`` channels of each input. Both results are returned in
    the original dtype of ``q_pe``.
    """
    out_dtype = q_pe.dtype

    q_fp32 = q_pe.float()
    k_fp32 = k_pe.float()
    cache_fp32 = cos_sin_cache.float()

    q_slice = q_fp32[..., :rotary_dim]
    k_slice = k_fp32[..., :rotary_dim]

    # each selected cache row holds [cos | sin] along the last dimension
    cos, sin = cache_fp32[pos].chunk(2, dim=-1)

    q_rotated = _apply_rotary_emb(q_slice, cos, sin, False)
    k_rotated = _apply_rotary_emb(k_slice, cos, sin, False)
    return q_rotated.to(out_dtype), k_rotated.to(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def native_torch(
    q_input,
    hidden_states,
    q_a_proj_weight,
    norm_weight1,
    q_b_proj_weight,
    w_kc,
    kv_a_proj_weight,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Pure-PyTorch reference for the fused ``qkv_proj_with_rope`` op.

    ``q_input`` is a caller-provided buffer that is filled in place and
    returned. ``k_input`` is a view of the kv projection result, which is
    likewise overwritten in place with the normalised latent part and the
    rotated rope part.

    Returns ``(q_input, k_input, v_input)``.
    """
    # query path: down-projection -> norm -> up-projection, split per head
    q_latent = layernorm(torch.matmul(hidden_states, q_a_proj_weight.t()), norm_weight1)
    q_heads = torch.matmul(q_latent, q_b_proj_weight.t()).view(
        -1, num_heads, qk_head_dim
    )
    q_nope, q_pe = torch.split(q_heads, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)

    # per-head batched matmul of the nope part with w_kc, stored in the first kv_lora_rank channels
    q_input[..., :kv_lora_rank] = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(
        0, 1
    )

    # key/value path: one projection yields the latent part and the rope part
    latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
    latent_part = latent_cache[..., :kv_lora_rank].contiguous()
    v_input = layernorm(latent_part, norm_weight2).unsqueeze(1)

    # k_input aliases latent_cache, so these writes update it in place
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input

    q_rotated, k_rotated = rotary_emb(
        q_pe, k_input[..., kv_lora_rank:], pos, cos_sin_cache
    )
    q_input[..., kv_lora_rank:] = q_rotated
    k_input[..., kv_lora_rank:] = k_rotated

    return q_input, k_input, v_input
|
||||||
|
|
||||||
|
|
||||||
|
def native_torch_int8(
    q_input,
    hidden_states,
    w1_q,
    w1_s,
    norm_weight1,
    w2_q,
    w2_s,
    w_kc,
    w3_q,
    w3_s,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Reference for ``qkv_proj_with_rope`` with int8 w8a8 projections.

    Each of the three projections quantises its activation with
    ``per_token_quant_int8`` and multiplies it with the pre-quantised weight
    (``w*_q`` with scales ``w*_s``) through ``native_w8a8_per_token_matmul``,
    producing bfloat16. ``q_input`` is filled in place; ``k_input`` is a view
    of the kv projection result and is also updated in place.

    Returns ``(q_input, k_input, v_input)``.
    """
    # query path: quantised down-projection -> norm -> quantised up-projection
    act_q, act_s = per_token_quant_int8(hidden_states)
    q_latent = layernorm(
        native_w8a8_per_token_matmul(act_q, w1_q, act_s, w1_s, None, torch.bfloat16),
        norm_weight1,
    )

    act_q, act_s = per_token_quant_int8(q_latent)
    q_heads = native_w8a8_per_token_matmul(
        act_q, w2_q, act_s, w2_s, None, torch.bfloat16
    ).view(-1, num_heads, qk_head_dim)
    q_nope, q_pe = torch.split(q_heads, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)

    q_input[..., :kv_lora_rank] = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(
        0, 1
    )

    # key/value path: quantised projection of the original hidden states
    act_q, act_s = per_token_quant_int8(hidden_states)
    latent_cache = native_w8a8_per_token_matmul(
        act_q, w3_q, act_s, w3_s, None, torch.bfloat16
    )
    latent_part = latent_cache[..., :kv_lora_rank].contiguous()
    v_input = layernorm(latent_part, norm_weight2).unsqueeze(1)

    # k_input aliases latent_cache, so these writes update it in place
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input

    q_rotated, k_rotated = rotary_emb(
        q_pe, k_input[..., kv_lora_rank:], pos, cos_sin_cache
    )
    q_input[..., kv_lora_rank:] = q_rotated
    k_input[..., kv_lora_rank:] = k_rotated

    return q_input, k_input, v_input
|
||||||
|
|
||||||
|
|
||||||
|
class TestQKVProjWithROPE(CustomTestCase):
    """Compares the fused CPU ``qkv_proj_with_rope`` op with PyTorch references
    for bf16, int8 (w8a8) and fp8 (w8a16) projection weights."""

    def _random_inputs(self, dtype):
        """Create the random tensors shared by every test.

        The creation order is kept fixed so the values drawn from the random
        generator do not depend on which test asks for them.
        """
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
        return (
            hidden_states,
            q_input,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc,
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

    def _assert_qkv_close(self, q_ref, k_ref, v_ref, q_out, k_out, v_out):
        """Check q, k and v against their references with the dtype's tolerance."""
        atol = rtol = precision[q_ref.dtype]
        for expected, actual in ((q_ref, q_out), (k_ref, k_out), (v_ref, v_out)):
            self.assertTrue(torch.allclose(expected, actual, atol=atol, rtol=rtol))

    def test_bf16_qkv_proj_with_rope(self):
        (
            hidden_states,
            q_input,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc,
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        ) = self._random_inputs(torch.bfloat16)

        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc.transpose(1, 2),
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        qa_packed = convert_weight_packed(q_a_proj_weight)
        qb_packed = convert_weight_packed(q_b_proj_weight)
        kva_packed = convert_weight_packed(kv_a_proj_weight)
        wkc_packed = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            qa_packed,
            qb_packed,
            kva_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,  # use_int8_w8a8
            False,  # use_fp8_w8a16
            None,  # q_a_proj_scale
            None,  # q_b_proj_scale
            None,  # kv_a_proj_scale
            True,  # is_vnni
            None,  # block_size
        )
        self._assert_qkv_close(q_ref, k_ref, v_ref, q_out, k_out, v_out)

    def test_int8_qkv_proj_with_rope(self):
        (
            hidden_states,
            q_input,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc,
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        ) = self._random_inputs(torch.bfloat16)

        w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
        w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
        w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)

        q_ref, k_ref, v_ref = native_torch_int8(
            q_input,
            hidden_states,
            w1_q,
            w1_s,
            norm_weight1,
            w2_q,
            w2_s,
            w_kc.transpose(1, 2),
            w3_q,
            w3_s,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        w1_q_packed = convert_weight_packed(w1_q)
        w2_q_packed = convert_weight_packed(w2_q)
        w3_q_packed = convert_weight_packed(w3_q)
        wkc_packed = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            w1_q_packed,
            w2_q_packed,
            w3_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,  # use_int8_w8a8
            False,  # use_fp8_w8a16
            w1_s,
            w2_s,
            w3_s,
            True,  # is_vnni
            None,  # block_size
        )
        self._assert_qkv_close(q_ref, k_ref, v_ref, q_out, k_out, v_out)

    def test_fp8_qkv_proj_with_rope(self):
        (
            hidden_states,
            q_input,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc,
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        ) = self._random_inputs(torch.bfloat16)

        scale_block_size_N = 128
        scale_block_size_K = 128
        scale_block = [scale_block_size_N, scale_block_size_K]

        # each conversion returns (fp8 weight, scale, dequantised weight)
        fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
            convert_weight(q_a_proj_weight, scale_block, torch.bfloat16)
        )
        fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
            convert_weight(q_b_proj_weight, scale_block, torch.bfloat16)
        )
        (
            fp8_kv_a_proj_with_mqa_weight,
            kv_a_proj_with_mqa_weight_scale_inv,
            kv_a_proj_with_mqa_weight_dq,
        ) = convert_weight(kv_a_proj_weight, scale_block, torch.bfloat16)

        # the reference runs on the dequantised weights
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight_dq,
            norm_weight1,
            q_b_proj_weight_dq,
            w_kc.transpose(1, 2),
            kv_a_proj_with_mqa_weight_dq,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight)
        fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight)
        fp8_kv_a_proj_with_mqa_weight = convert_weight_packed(
            fp8_kv_a_proj_with_mqa_weight
        )
        w_kc = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            fp8_q_a_proj_weight,
            fp8_q_b_proj_weight,
            fp8_kv_a_proj_with_mqa_weight,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,  # use_int8_w8a8
            True,  # use_fp8_w8a16
            q_a_proj_weight_scale_inv.float(),
            q_b_proj_weight_scale_inv.float(),
            kv_a_proj_with_mqa_weight_scale_inv.float(),
            True,  # is_vnni
            [scale_block_size_N, scale_block_size_K],
        )
        self._assert_qkv_close(q_ref, k_ref, v_ref, q_out, k_out, v_out)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user