Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT (#6493)

This commit is contained in:
blzheng
2025-05-23 15:14:46 +08:00
committed by GitHub
parent 4685fbb888
commit 4ba1eea83f
5 changed files with 483 additions and 11 deletions

View File

@@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni);
bool is_vnni,
std::optional<std::vector<int64_t>> block_size);
// shared memory init
void initialize(int64_t size, int64_t rank);
@@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// decode
m.def(
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
"Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
"float logit_cap) -> ()");
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
// extend
@@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def(
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
"q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
// shared expert