Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add unit tests (#6493)
This commit is contained in:
@@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor> q_a_proj_scale,
|
||||
std::optional<at::Tensor> q_b_proj_scale,
|
||||
std::optional<at::Tensor> kv_a_proj_scale,
|
||||
bool is_vnni);
|
||||
bool is_vnni,
|
||||
std::optional<std::vector<int64_t>> block_size);
|
||||
|
||||
// shared memory init
|
||||
void initialize(int64_t size, int64_t rank);
|
||||
@@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
// decode
|
||||
m.def(
|
||||
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
|
||||
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
|
||||
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
|
||||
"Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
|
||||
"float logit_cap) -> ()");
|
||||
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
|
||||
|
||||
// extend
|
||||
@@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def(
|
||||
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
|
||||
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
|
||||
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
|
||||
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
|
||||
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
|
||||
"q_b_proj_scale, Tensor? "
|
||||
"kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
|
||||
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
|
||||
|
||||
// shared expert
|
||||
|
||||
Reference in New Issue
Block a user