Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT (#6493)

This commit is contained in:
blzheng
2025-05-23 15:14:46 +08:00
committed by GitHub
parent 4685fbb888
commit 4ba1eea83f
5 changed files with 483 additions and 11 deletions

View File

@@ -152,6 +152,85 @@ void segment_gemm_kernel_impl(
});
}
// [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
scalar_t* __restrict__ C0,
scalar_t* __restrict__ C1,
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B0,
const at::Float8_e4m3fn* __restrict__ B1,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
int64_t M,
int64_t N0,
int64_t N1,
int64_t K,
int64_t block_size_N,
int64_t block_size_K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB0 = div_up(N0, BLOCK_N);
const int64_t NB1 = div_up(N1, BLOCK_N);
const int64_t NB = NB0 + NB1;
const int64_t scale_size_K = div_up(K, block_size_K);
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
// parallel on [MB, NB0 + NB1]
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
// for brgemm when mat2 is float8_e4m3
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = BLOCK_N;
const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
int64_t ldc = nb < NB0 ? N0 : N1;
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
int64_t new_nb = nb < NB0 ? nb : nb - NB0;
tinygemm_kernel<scalar_t>(
/* A */ A + mb_start * K,
/* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
/* C */ C + mb_start * ldc + local_nb_start,
/* Btmp*/ Btmp,
/* Ctmp*/ Ctmp,
/* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ ldc,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
template <typename scalar_t>
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
@@ -321,6 +400,15 @@ extern at::Tensor int8_scaled_mm_with_quant(
extern void
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
extern at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// NB: shapes in DeepDeek R1
//
// hidden_states : [num_seqs, hidden_size] [1, 7168]
@@ -343,10 +431,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni) {
bool is_vnni,
std::optional<std::vector<int64_t>> block_size) {
RECORD_FUNCTION(
"sgl-kernel::qkv_proj_with_rope",
std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
@@ -394,7 +484,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
}
if (use_fp8_w8a16) {
TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
}
// outputs and temp buffer
const auto options = hidden_states.options();
auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
@@ -436,6 +532,29 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
q_lora_rank,
kv_lora_rank + qk_rope_head_dim,
hidden_size);
} else if (use_fp8_w8a16) {
int64_t block_size_N = block_size.value()[0];
int64_t block_size_K = block_size.value()[1];
auto q_a_proj_s = q_a_proj_scale.value();
auto kv_a_proj_s = kv_a_proj_scale.value();
CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
segment_gemm_kernel_impl<scalar_t>(
qa.data_ptr<scalar_t>(),
k_input.data_ptr<scalar_t>(),
hidden_states.data_ptr<scalar_t>(),
q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
q_a_proj_s.data_ptr<float>(),
kv_a_proj_s.data_ptr<float>(),
num_seqs,
q_lora_rank,
kv_lora_rank + qk_rope_head_dim,
hidden_size,
block_size_N,
block_size_K);
} else {
segment_gemm_kernel_impl<scalar_t>(
qa.data_ptr<scalar_t>(),
@@ -469,6 +588,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
std::optional<at::Tensor> bias;
if (use_int8_w8a8) {
qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
} else if (use_fp8_w8a16) {
qb = fp8_scaled_mm_cpu(
qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
} else {
qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
}

View File

@@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni);
bool is_vnni,
std::optional<std::vector<int64_t>> block_size);
// shared memory init
void initialize(int64_t size, int64_t rank);
@@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// decode
m.def(
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
"Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
"float logit_cap) -> ()");
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
// extend
@@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def(
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
"q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
// shared expert