CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
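
Summary of the mapped changes (as reflected in the diff below): segment_gemm_kernel_impl now takes an explicit Btmp scratch pointer instead of a fixed-size thread-local stack array, with each thread addressing its own slice at tid * BLOCK_N * K; qkv_proj_with_rope allocates that scratch as a {num_threads, BLOCK_N * hidden_size} tensor; and a new qkv_proj_with_rope_fused_weight entry point splits a fused qkv_a projection weight (and its int8/fp8 scales) into q_a and kv_a chunks and forwards to qkv_proj_with_rope.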
@@ -162,6 +162,7 @@ void segment_gemm_kernel_impl(
     const at::Float8_e4m3fn* __restrict__ B1,
     const float* __restrict__ Bs0,
     const float* __restrict__ Bs1,
+    scalar_t* __restrict__ Btmp,
     int64_t M,
     int64_t N0,
     int64_t N1,
@@ -185,10 +186,9 @@ void segment_gemm_kernel_impl(
     int64_t mb{0}, nb{0};
     data_index_init(begin, mb, MB, nb, NB);
 
+    int tid = at::get_thread_num();
     // for brgemm, use float32 for accumulate
     alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
-    // for brgemm when mat2 is float8_e4m3
-    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
 
     for (int64_t i = begin; i < end; ++i) {
       UNUSED(i);
@@ -209,7 +209,7 @@ void segment_gemm_kernel_impl(
           /* A */ A + mb_start * K,
           /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
           /* C */ C + mb_start * ldc + local_nb_start,
-          /* Btmp*/ Btmp,
+          /* Btmp*/ Btmp + tid * BLOCK_N * K,
           /* Ctmp*/ Ctmp,
          /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
           /* M */ mb_size,
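
Note on the two hunks above: the brgemm scratch for fp8 weights moves from a per-thread stack array to per-thread slices of a caller-provided buffer, so the panel can span the full K dimension without stack pressure. A minimal standalone sketch of that pattern follows; function and variable names here are illustrative, not the actual kernel:

#include <ATen/Parallel.h>
#include <cstdint>

// Each thread packs weight panels into its own slice of one shared scratch
// buffer (num_threads * BLOCK_N * K floats, allocated by the caller), so no
// thread touches another's scratch and the panel size is not stack-limited.
void pack_panels(float* scratch, int64_t BLOCK_N, int64_t K, int64_t num_panels) {
  at::parallel_for(0, num_panels, 0, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    float* Btmp = scratch + tid * BLOCK_N * K;  // this thread's private slice
    for (int64_t p = begin; p < end; ++p) {
      // ... dequantize / repack one BLOCK_N x K weight panel into Btmp here ...
      (void)Btmp;
    }
  });
}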
@@ -541,6 +541,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
|
||||
CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
|
||||
CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
|
||||
|
||||
const int BLOCK_N = block_size_n();
|
||||
const int num_threads = at::get_num_threads();
|
||||
auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options);
|
||||
segment_gemm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
k_input.data_ptr<scalar_t>(),
|
||||
@@ -549,6 +553,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
       kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
       q_a_proj_s.data_ptr<float>(),
       kv_a_proj_s.data_ptr<float>(),
+      buffer.data_ptr<scalar_t>(),
       num_seqs,
       q_lora_rank,
       kv_lora_rank + qk_rope_head_dim,
@@ -624,3 +629,74 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
 
   return std::make_tuple(q_input, k_input, v_input);
 }
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
+    at::Tensor& hidden_states,
+    at::Tensor& qkv_a_proj_weight,
+    at::Tensor& q_b_proj_weight,
+    at::Tensor& w_kc,
+    at::Tensor& q_a_layernorm_weight,
+    at::Tensor& kv_a_layernorm_weight,
+    at::Tensor& positions,
+    at::Tensor& cos_sin_cache,
+    double eps,
+    bool use_int8_w8a8,
+    bool use_fp8_w8a16,
+    std::optional<at::Tensor> qkv_a_proj_scale,
+    std::optional<at::Tensor> q_b_proj_scale,
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size,
+    int64_t q_lora_rank,
+    int64_t kv_lora_rank,
+    int64_t qk_rope_head_dim) {
+  RECORD_FUNCTION(
+      "sgl-kernel::qkv_proj_with_rope_fused_weight",
+      std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc}));
+
+  int64_t hidden_size = hidden_states.size(1);
+  CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim);
+  CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
+
+  std::vector<at::Tensor> weight_chunks =
+      at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
+  at::Tensor q_a_proj_weight = weight_chunks[0];
+  at::Tensor kv_a_proj_weight = weight_chunks[1];
+  at::Tensor q_a_proj_s;
+  at::Tensor kv_a_proj_s;
+
+  if (use_int8_w8a8) {
+    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8.");
+    std::vector<at::Tensor> scale_chunks =
+        at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
+    q_a_proj_s = scale_chunks[0];
+    kv_a_proj_s = scale_chunks[1];
+  }
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
+    int64_t block_size_N = block_size.value()[0];
+    int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N);
+    int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N);
+    std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0);
+    q_a_proj_s = scale_chunks[0];
+    kv_a_proj_s = scale_chunks[1];
+  }
+
+  return qkv_proj_with_rope(
+      hidden_states,
+      q_a_proj_weight,
+      q_b_proj_weight,
+      kv_a_proj_weight,
+      w_kc,
+      q_a_layernorm_weight,
+      kv_a_layernorm_weight,
+      positions,
+      cos_sin_cache,
+      eps,
+      use_int8_w8a8,
+      use_fp8_w8a16,
+      q_a_proj_s,
+      q_b_proj_scale,
+      kv_a_proj_s,
+      is_vnni,
+      block_size);
+}
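
To make the new entry point's weight split concrete, here is a small self-contained libtorch sketch mirroring the at::split call above; the dimensions are hypothetical placeholders, chosen only to make the shapes concrete:

#include <torch/torch.h>

int main() {
  // Hypothetical MLA-style sizes, for illustration only.
  int64_t q_lora_rank = 1536;
  int64_t kv_lora_rank = 512;
  int64_t qk_rope_head_dim = 64;
  int64_t hidden_size = 7168;

  // Fused weight: q_a rows first, then kv_a (+ rope) rows.
  auto fused = torch::randn({q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size});

  // Same split sizes along dim 0 as qkv_proj_with_rope_fused_weight uses.
  auto chunks = at::split(fused, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
  auto q_a_proj_weight = chunks[0];   // [1536, 7168]
  auto kv_a_proj_weight = chunks[1];  // [576, 7168]
  return 0;
}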