CPU: map changes from development branch in sgl-kernel (#6833)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
Author: YanbingJiang
Committed by: GitHub
Date: 2025-06-10 16:08:15 +08:00
Parent: 81372f3bef
Commit: fcde67b016
20 changed files with 1321 additions and 321 deletions


@@ -162,6 +162,7 @@ void segment_gemm_kernel_impl(
     const at::Float8_e4m3fn* __restrict__ B1,
     const float* __restrict__ Bs0,
     const float* __restrict__ Bs1,
+    scalar_t* __restrict__ Btmp,
    int64_t M,
    int64_t N0,
    int64_t N1,
@@ -185,10 +186,9 @@ void segment_gemm_kernel_impl(
     int64_t mb{0}, nb{0};
     data_index_init(begin, mb, MB, nb, NB);
+    int tid = at::get_thread_num();
 
     // for brgemm, use float32 for accumulate
     alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
-    // for brgemm when mat2 is float8_e4m3
-    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
 
     for (int64_t i = begin; i < end; ++i) {
       UNUSED(i);
@@ -209,7 +209,7 @@ void segment_gemm_kernel_impl(
         /* A */ A + mb_start * K,
         /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
         /* C */ C + mb_start * ldc + local_nb_start,
-        /* Btmp*/ Btmp,
+        /* Btmp*/ Btmp + tid * BLOCK_N * K,
         /* Ctmp*/ Ctmp,
         /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
         /* M */ mb_size,
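
The three hunks above are one change: the fp8 dequant scratch `Btmp` moves from a per-iteration `alignas(64)` stack array to a caller-allocated buffer that is partitioned by thread, with each worker offsetting into it by `tid * BLOCK_N * K` so slices never overlap. A minimal standalone sketch of that pattern, with an illustrative `SLICE` size standing in for `BLOCK_N * K` (names are not from the commit):

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

// Sketch only: the real kernel sizes the buffer at the call site and
// threads it through as `Btmp`.
void run_with_thread_local_scratch(int64_t num_tasks) {
  const int num_threads = at::get_num_threads();
  const int64_t SLICE = 64 * 128;  // stand-in for BLOCK_N * K

  // One allocation shared by all threads; each thread writes only its own
  // row, so no synchronization is needed and nothing lives on the stack.
  auto buffer = at::empty({num_threads, SLICE}, at::kBFloat16);
  at::BFloat16* base = buffer.data_ptr<at::BFloat16>();

  at::parallel_for(0, num_tasks, 0, [&](int64_t begin, int64_t end) {
    const int tid = at::get_thread_num();
    at::BFloat16* scratch = base + tid * SLICE;  // this thread's private slice
    for (int64_t i = begin; i < end; ++i) {
      scratch[0] = static_cast<at::BFloat16>(float(i));  // placeholder work
    }
  });
}

This keeps a potentially large panel off each task's stack and lets the caller size the workspace once per launch instead of once per loop iteration.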
@@ -541,6 +541,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
   CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
   CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
   CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
+
+  const int BLOCK_N = block_size_n();
+  const int num_threads = at::get_num_threads();
+  auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options);
   segment_gemm_kernel_impl<scalar_t>(
       qa.data_ptr<scalar_t>(),
       k_input.data_ptr<scalar_t>(),
@@ -549,6 +553,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
       kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
       q_a_proj_s.data_ptr<float>(),
       kv_a_proj_s.data_ptr<float>(),
+      buffer.data_ptr<scalar_t>(),
       num_seqs,
       q_lora_rank,
       kv_lora_rank + qk_rope_head_dim,
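
The matching allocation above reserves one row of `BLOCK_N * hidden_size` elements per thread, enough for the widest panel either projection dequantizes, since K equals `hidden_size` for both GEMMs at this call site. For intuition only, a simplified sketch of the kind of upconversion such a panel backs when mat2 is `float8_e4m3`; the scale layout and plain loops here are assumptions for illustration, not the kernel's actual brgemm tiling:

#include <ATen/ATen.h>
#include <c10/util/Float8_e4m3fn.h>

// Hypothetical helper: upconvert one fp8 weight block into the bf16/fp16
// scratch panel before brgemm, applying a per-K-block dequant scale.
template <typename scalar_t>
void dequant_b_block(
    const at::Float8_e4m3fn* __restrict__ B,  // fp8 weight block, N x K
    const float* __restrict__ scale,          // one scale per K block
    scalar_t* __restrict__ Btmp,              // this thread's scratch panel
    int64_t N,
    int64_t K,
    int64_t block_size_K) {
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t k = 0; k < K; ++k) {
      const float s = scale[k / block_size_K];
      Btmp[n * K + k] = static_cast<scalar_t>(static_cast<float>(B[n * K + k]) * s);
    }
  }
}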
@@ -624,3 +629,74 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
   return std::make_tuple(q_input, k_input, v_input);
 }
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
+    at::Tensor& hidden_states,
+    at::Tensor& qkv_a_proj_weight,
+    at::Tensor& q_b_proj_weight,
+    at::Tensor& w_kc,
+    at::Tensor& q_a_layernorm_weight,
+    at::Tensor& kv_a_layernorm_weight,
+    at::Tensor& positions,
+    at::Tensor& cos_sin_cache,
+    double eps,
+    bool use_int8_w8a8,
+    bool use_fp8_w8a16,
+    std::optional<at::Tensor> qkv_a_proj_scale,
+    std::optional<at::Tensor> q_b_proj_scale,
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size,
+    int64_t q_lora_rank,
+    int64_t kv_lora_rank,
+    int64_t qk_rope_head_dim) {
+  RECORD_FUNCTION(
+      "sgl-kernel::qkv_proj_with_rope_fused_weight",
+      std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc}));
+
+  int64_t hidden_size = hidden_states.size(1);
+  CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim);
+  CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
+
+  std::vector<at::Tensor> weight_chunks =
+      at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
+  at::Tensor q_a_proj_weight = weight_chunks[0];
+  at::Tensor kv_a_proj_weight = weight_chunks[1];
+
+  at::Tensor q_a_proj_s;
+  at::Tensor kv_a_proj_s;
+  if (use_int8_w8a8) {
+    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8.");
+    std::vector<at::Tensor> scale_chunks =
+        at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
+    q_a_proj_s = scale_chunks[0];
+    kv_a_proj_s = scale_chunks[1];
+  }
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
+    int64_t block_size_N = block_size.value()[0];
+    int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N);
+    int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N);
+    std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0);
+    q_a_proj_s = scale_chunks[0];
+    kv_a_proj_s = scale_chunks[1];
+  }
+
+  return qkv_proj_with_rope(
+      hidden_states,
+      q_a_proj_weight,
+      q_b_proj_weight,
+      kv_a_proj_weight,
+      w_kc,
+      q_a_layernorm_weight,
+      kv_a_layernorm_weight,
+      positions,
+      cos_sin_cache,
+      eps,
+      use_int8_w8a8,
+      use_fp8_w8a16,
+      q_a_proj_s,
+      q_b_proj_scale,
+      kv_a_proj_s,
+      is_vnni,
+      block_size);
+}
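
The new wrapper adds no kernel work of its own: it slices the fused `qkv_a_proj` weight (and its scale tensor, whose split points differ between the int8 and fp8 layouts) back into the `q_a` and `kv_a` pieces, then forwards everything to the existing `qkv_proj_with_rope`. A small standalone sketch of that partitioning; the dimensions are hypothetical DeepSeek-style values, not taken from this commit:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  const int64_t q_lora_rank = 1536;
  const int64_t kv_lora_rank = 512;
  const int64_t qk_rope_head_dim = 64;
  const int64_t hidden_size = 7168;

  // Fused [q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size] weight.
  auto qkv_a_proj_weight =
      at::randn({q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size});

  // at::split along dim 0 returns views over the fused storage, so the two
  // projections are recovered without copying any weight data.
  auto chunks = at::split(
      qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, /*dim=*/0);
  std::cout << chunks[0].sizes() << " " << chunks[1].sizes() << std::endl;
  return 0;
}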