Add CPU optimized kernels for topk and rope fusions (#6456)
This commit is contained in:
@@ -4,6 +4,67 @@
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// NB: avoid using `at::vec::map<>` on bfloat16 or half
|
// NB: avoid using `at::vec::map<>` on bfloat16 or half
|
||||||
|
// Llama4TextL2Norm
|
||||||
|
template <typename scalar_t>
|
||||||
|
void l2norm_kernel_impl(
|
||||||
|
scalar_t* __restrict__ output,
|
||||||
|
const scalar_t* __restrict__ input,
|
||||||
|
int64_t batch_size,
|
||||||
|
int64_t hidden_size,
|
||||||
|
float eps = 1e-5) {
|
||||||
|
using bVec = at::vec::Vectorized<scalar_t>;
|
||||||
|
using fVec = at::vec::Vectorized<float>;
|
||||||
|
|
||||||
|
constexpr int kVecSize = bVec::size();
|
||||||
|
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
for (int64_t i = begin; i < end; ++i) {
|
||||||
|
// local ptrs
|
||||||
|
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
|
||||||
|
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
|
||||||
|
|
||||||
|
fVec sum_fvec = fVec(float(0));
|
||||||
|
float sum_val = float(0);
|
||||||
|
|
||||||
|
int64_t d;
|
||||||
|
#pragma GCC unroll 4
|
||||||
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||||
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||||
|
fVec x_fvec0, x_fvec1;
|
||||||
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||||
|
|
||||||
|
sum_fvec += x_fvec0 * x_fvec0;
|
||||||
|
sum_fvec += x_fvec1 * x_fvec1;
|
||||||
|
}
|
||||||
|
#pragma GCC unroll 4
|
||||||
|
for (; d < hidden_size; ++d) {
|
||||||
|
float x_val = static_cast<float>(input_ptr[d]);
|
||||||
|
sum_val += x_val * x_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
sum_val += vec_reduce_sum(sum_fvec);
|
||||||
|
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
||||||
|
const fVec scale_fvec = fVec(rsqrt_var);
|
||||||
|
|
||||||
|
#pragma GCC unroll 4
|
||||||
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||||
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||||
|
fVec x_fvec0, x_fvec1;
|
||||||
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||||
|
|
||||||
|
x_fvec0 = x_fvec0 * scale_fvec;
|
||||||
|
x_fvec1 = x_fvec1 * scale_fvec;
|
||||||
|
|
||||||
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||||
|
out_bvec.store(out_ptr + d);
|
||||||
|
}
|
||||||
|
#pragma GCC unroll 4
|
||||||
|
for (; d < hidden_size; ++d) {
|
||||||
|
float x_val = static_cast<float>(input_ptr[d]);
|
||||||
|
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void rmsnorm_kernel_impl(
|
void rmsnorm_kernel_impl(
|
||||||
scalar_t* __restrict__ output,
|
scalar_t* __restrict__ output,
|
||||||
@@ -160,6 +221,22 @@ void fused_add_rmsnorm_kernel_impl(
|
|||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
// input : {batch_size, hidden_size}
|
||||||
|
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
|
||||||
|
RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));
|
||||||
|
|
||||||
|
CHECK_INPUT(input);
|
||||||
|
CHECK_DIM(2, input);
|
||||||
|
int64_t batch_size = input.size(0);
|
||||||
|
int64_t hidden_size = input.size(1);
|
||||||
|
at::Tensor output = at::empty_like(input);
|
||||||
|
|
||||||
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
|
||||||
|
l2norm_kernel_impl<scalar_t>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), batch_size, hidden_size, eps);
|
||||||
|
});
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
// input : {batch_size, hidden_size}
|
// input : {batch_size, hidden_size}
|
||||||
// weight: {hidden_size}
|
// weight: {hidden_size}
|
||||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||||
|
|||||||
@@ -4,126 +4,343 @@
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void rope_kernel_impl(
|
void rotary_embedding_3D_kernel_impl(
|
||||||
scalar_t* __restrict__ q_pe_out,
|
scalar_t* __restrict__ query_out,
|
||||||
scalar_t* __restrict__ k_pe_out,
|
scalar_t* __restrict__ key_out,
|
||||||
int64_t* __restrict__ t_pos,
|
int64_t* __restrict__ positions,
|
||||||
scalar_t* __restrict__ q_pe,
|
scalar_t* __restrict__ query,
|
||||||
scalar_t* __restrict__ k_pe,
|
scalar_t* __restrict__ key,
|
||||||
scalar_t* __restrict__ t_emb_pos,
|
scalar_t* __restrict__ cos_sin_cache,
|
||||||
int64_t seq_len,
|
int64_t num_tokens,
|
||||||
int64_t num_head,
|
int64_t num_heads,
|
||||||
|
int64_t num_kv_heads,
|
||||||
|
int64_t head_size,
|
||||||
int64_t rotary_dim,
|
int64_t rotary_dim,
|
||||||
int64_t HR,
|
int64_t query_stride_s,
|
||||||
int64_t q_pe_stride_s,
|
int64_t query_out_stride_s,
|
||||||
int64_t out_stride_qs,
|
int64_t key_out_stride_s,
|
||||||
int64_t out_stride_ks,
|
int64_t key_stride_s,
|
||||||
int64_t HK,
|
int64_t query_stride_h,
|
||||||
int64_t k_pe_stride_s,
|
int64_t query_out_stride_h) {
|
||||||
int64_t q_pe_stride_n,
|
int64_t HR = rotary_dim;
|
||||||
int64_t out_stride_qn) {
|
int64_t HK = rotary_dim;
|
||||||
int64_t COFF = HR / 2;
|
int64_t COFF = HR / 2;
|
||||||
at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||||
int64_t seq{0}, head_id{0};
|
int64_t seq{0}, head_id{0};
|
||||||
data_index_init(begin, seq, seq_len, head_id, num_head);
|
data_index_init(begin, seq, num_tokens, head_id, num_heads);
|
||||||
for (int64_t i = begin; i < end; ++i) {
|
for (int64_t i = begin; i < end; ++i) {
|
||||||
int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
|
int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
|
||||||
int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
|
int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
|
||||||
int64_t out_offset_k = seq * out_stride_ks;
|
int64_t out_offset_k = seq * key_out_stride_s;
|
||||||
int64_t p = 0;
|
int64_t p = 0;
|
||||||
scalar_t* sin_start = nullptr;
|
scalar_t* sin_start = nullptr;
|
||||||
scalar_t* cos_start = nullptr;
|
scalar_t* cos_start = nullptr;
|
||||||
// step 0) get the rotary position embedding for the current position
|
// step 0) get the rotary position embedding for the current position
|
||||||
p = t_pos[seq];
|
p = positions[seq];
|
||||||
sin_start = t_emb_pos + p * HR + COFF;
|
sin_start = cos_sin_cache + p * HR + COFF;
|
||||||
cos_start = t_emb_pos + p * HR;
|
cos_start = cos_sin_cache + p * HR;
|
||||||
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
|
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
|
||||||
// head of query/key
|
// head of query/key
|
||||||
for (int64_t h = 0; h < rotary_dim; h += 2) {
|
for (int64_t h = 0; h < rotary_dim; h += 2) {
|
||||||
scalar_t cos = cos_start[h >> 1];
|
scalar_t cos = cos_start[h >> 1];
|
||||||
scalar_t sin = sin_start[h >> 1];
|
scalar_t sin = sin_start[h >> 1];
|
||||||
scalar_t in1 = q_pe[in_offset_q + h];
|
scalar_t in1 = query[in_offset_q + h];
|
||||||
scalar_t in2 = q_pe[in_offset_q + h + 1];
|
scalar_t in2 = query[in_offset_q + h + 1];
|
||||||
scalar_t out1 = in1 * cos - in2 * sin;
|
scalar_t out1 = in1 * cos - in2 * sin;
|
||||||
scalar_t out2 = in2 * cos + in1 * sin;
|
scalar_t out2 = in2 * cos + in1 * sin;
|
||||||
q_pe_out[out_offset_q + h] = out1;
|
query_out[out_offset_q + h] = out1;
|
||||||
q_pe_out[out_offset_q + h + 1] = out2;
|
query_out[out_offset_q + h + 1] = out2;
|
||||||
}
|
}
|
||||||
for (int64_t h = 0; h < HK; h += 2) {
|
for (int64_t h = 0; h < HK; h += 2) {
|
||||||
scalar_t cos = cos_start[h >> 1];
|
scalar_t cos = cos_start[h >> 1];
|
||||||
scalar_t sin = sin_start[h >> 1];
|
scalar_t sin = sin_start[h >> 1];
|
||||||
int64_t k_pe_offset = seq * k_pe_stride_s;
|
int64_t k_pe_offset = seq * key_stride_s;
|
||||||
scalar_t in1_k = k_pe[k_pe_offset + h];
|
scalar_t in1_k = key[k_pe_offset + h];
|
||||||
scalar_t in2_k = k_pe[k_pe_offset + h + 1];
|
scalar_t in2_k = key[k_pe_offset + h + 1];
|
||||||
scalar_t out1_k = in1_k * cos - in2_k * sin;
|
scalar_t out1_k = in1_k * cos - in2_k * sin;
|
||||||
scalar_t out2_k = in2_k * cos + in1_k * sin;
|
scalar_t out2_k = in2_k * cos + in1_k * sin;
|
||||||
k_pe_out[out_offset_k + h] = out1_k;
|
key_out[out_offset_k + h] = out1_k;
|
||||||
k_pe_out[out_offset_k + h + 1] = out2_k;
|
key_out[out_offset_k + h + 1] = out2_k;
|
||||||
}
|
}
|
||||||
// move to the next index
|
// move to the next index
|
||||||
data_index_step(seq, seq_len, head_id, num_head);
|
data_index_step(seq, num_tokens, head_id, num_heads);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void rotary_embedding_neox_2D_kernel_impl(
|
||||||
|
int64_t* __restrict__ positions,
|
||||||
|
scalar_t* __restrict__ query,
|
||||||
|
scalar_t* __restrict__ key,
|
||||||
|
scalar_t* __restrict__ cos_sin_cache,
|
||||||
|
int64_t rotary_dim,
|
||||||
|
int64_t query_stride_s,
|
||||||
|
int64_t key_stride_s,
|
||||||
|
int64_t num_heads,
|
||||||
|
int64_t num_kv_heads,
|
||||||
|
int64_t head_size,
|
||||||
|
int64_t num_tokens) {
|
||||||
|
using bVec = at::vec::Vectorized<scalar_t>;
|
||||||
|
using fVec = at::vec::Vectorized<float>;
|
||||||
|
constexpr int64_t bVecSize = bVec::size();
|
||||||
|
|
||||||
|
int64_t embed_dim = rotary_dim / 2;
|
||||||
|
bool flag = (embed_dim % bVecSize == 0);
|
||||||
|
int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;
|
||||||
|
|
||||||
|
auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
|
||||||
|
int64_t j = 0;
|
||||||
|
for (; j < loop_upper; j += bVecSize) {
|
||||||
|
int64_t rot_offset = j;
|
||||||
|
int64_t x_index = rot_offset;
|
||||||
|
int64_t y_index = embed_dim + rot_offset;
|
||||||
|
|
||||||
|
int64_t out_x = token_head + x_index;
|
||||||
|
int64_t out_y = token_head + y_index;
|
||||||
|
|
||||||
|
bVec _cos = bVec::loadu(cache_ptr + x_index);
|
||||||
|
bVec _sin = bVec::loadu(cache_ptr + y_index);
|
||||||
|
|
||||||
|
bVec _q_x = bVec::loadu(qk + out_x);
|
||||||
|
bVec _q_y = bVec::loadu(qk + out_y);
|
||||||
|
fVec _cos_0, _cos_1;
|
||||||
|
std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
|
||||||
|
fVec _sin_0, _sin_1;
|
||||||
|
std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
|
||||||
|
fVec _q_x_0, _q_x_1;
|
||||||
|
std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
|
||||||
|
fVec _q_y_0, _q_y_1;
|
||||||
|
std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);
|
||||||
|
|
||||||
|
auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
|
||||||
|
auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
|
||||||
|
auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
|
||||||
|
out1.store(qk + out_x);
|
||||||
|
|
||||||
|
auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
|
||||||
|
auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
|
||||||
|
auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
|
||||||
|
out2.store(qk + out_y);
|
||||||
|
}
|
||||||
|
if (!flag) {
|
||||||
|
for (; j < embed_dim; ++j) {
|
||||||
|
int64_t x_index = j;
|
||||||
|
int64_t y_index = embed_dim + j;
|
||||||
|
|
||||||
|
int64_t out_x = token_head + x_index;
|
||||||
|
int64_t out_y = token_head + y_index;
|
||||||
|
|
||||||
|
float _cos = cache_ptr[x_index];
|
||||||
|
float _sin = cache_ptr[y_index];
|
||||||
|
|
||||||
|
float _q_x = qk[out_x];
|
||||||
|
float _q_y = qk[out_y];
|
||||||
|
|
||||||
|
qk[out_x] = _q_x * _cos - _q_y * _sin;
|
||||||
|
qk[out_y] = _q_y * _cos + _q_x * _sin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
int64_t pos = positions[token_idx];
|
||||||
|
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < num_heads; ++i) {
|
||||||
|
int64_t head_idx = i;
|
||||||
|
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
|
||||||
|
compute_loop(token_head, cache_ptr, query);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < num_kv_heads; ++i) {
|
||||||
|
int64_t head_idx = i;
|
||||||
|
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
|
||||||
|
compute_loop(token_head, cache_ptr, key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void rotary_embedding_2D_kernel_impl(
|
||||||
|
int64_t* __restrict__ positions,
|
||||||
|
scalar_t* __restrict__ query,
|
||||||
|
scalar_t* __restrict__ key,
|
||||||
|
scalar_t* __restrict__ cos_sin_cache,
|
||||||
|
int64_t rotary_dim,
|
||||||
|
int64_t query_stride_s,
|
||||||
|
int64_t key_stride_s,
|
||||||
|
int64_t num_heads,
|
||||||
|
int64_t num_kv_heads,
|
||||||
|
int64_t head_size,
|
||||||
|
int64_t num_tokens) {
|
||||||
|
int64_t embed_dim = rotary_dim / 2;
|
||||||
|
|
||||||
|
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||||
|
int64_t token_idx = {0}, i = {0};
|
||||||
|
data_index_init(begin, token_idx, num_tokens, i, num_heads);
|
||||||
|
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
||||||
|
int64_t pos = positions[token_idx];
|
||||||
|
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||||
|
scalar_t* cos_cache_ptr = cache_ptr;
|
||||||
|
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||||
|
int64_t head_idx = i;
|
||||||
|
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
|
||||||
|
scalar_t* head_query = token_head + query;
|
||||||
|
for (int64_t j = 0; j < embed_dim; j += 1) {
|
||||||
|
int64_t rot_offset = j;
|
||||||
|
int64_t x_index = 2 * rot_offset;
|
||||||
|
int64_t y_index = 2 * rot_offset + 1;
|
||||||
|
|
||||||
|
float cos = cos_cache_ptr[rot_offset];
|
||||||
|
float sin = sin_cache_ptr[rot_offset];
|
||||||
|
|
||||||
|
float x = head_query[x_index];
|
||||||
|
float y = head_query[y_index];
|
||||||
|
|
||||||
|
head_query[x_index] = x * cos - y * sin;
|
||||||
|
head_query[y_index] = y * cos + x * sin;
|
||||||
|
}
|
||||||
|
data_index_step(token_idx, num_tokens, i, num_heads);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||||
|
int64_t token_idx{0}, i = {0};
|
||||||
|
data_index_init(begin, token_idx, num_tokens, i, num_kv_heads);
|
||||||
|
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
||||||
|
int64_t pos = positions[token_idx];
|
||||||
|
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||||
|
scalar_t* cos_cache_ptr = cache_ptr;
|
||||||
|
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||||
|
int64_t head_idx = i;
|
||||||
|
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
|
||||||
|
scalar_t* head_key = key + token_head;
|
||||||
|
for (int64_t j = 0; j < embed_dim; j += 1) {
|
||||||
|
int64_t rot_offset = j;
|
||||||
|
int64_t x_index = 2 * rot_offset;
|
||||||
|
int64_t y_index = 2 * rot_offset + 1;
|
||||||
|
|
||||||
|
float cos = cos_cache_ptr[rot_offset];
|
||||||
|
float sin = sin_cache_ptr[rot_offset];
|
||||||
|
|
||||||
|
float x = head_key[x_index];
|
||||||
|
float y = head_key[y_index];
|
||||||
|
|
||||||
|
head_key[x_index] = x * cos - y * sin;
|
||||||
|
head_key[y_index] = y * cos + x * sin;
|
||||||
|
}
|
||||||
|
data_index_step(token_idx, num_tokens, i, num_kv_heads);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
|
||||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
|
at::Tensor& positions,
|
||||||
RECORD_FUNCTION(
|
at::Tensor& query,
|
||||||
"sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
|
at::Tensor& key,
|
||||||
CHECK_INPUT(t_pos);
|
int64_t head_size,
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
|
at::Tensor& cos_sin_cache,
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
|
bool is_neox) {
|
||||||
CHECK_INPUT(t_emb_pos);
|
RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));
|
||||||
CHECK_DIM(1, t_pos);
|
CHECK_DIM(1, positions);
|
||||||
CHECK_DIM(3, q_pe);
|
const auto input_dim = query.dim();
|
||||||
CHECK_DIM(3, k_pe);
|
const auto input_dtype = query.scalar_type();
|
||||||
CHECK_DIM(2, t_emb_pos);
|
TORCH_CHECK(
|
||||||
|
input_dim == 2 || input_dim == 3,
|
||||||
|
" Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
|
||||||
|
CHECK_DIM(2, cos_sin_cache);
|
||||||
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
|
||||||
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
|
||||||
|
|
||||||
int64_t seq_len = q_pe.size(0);
|
int64_t rotary_dim = cos_sin_cache.size(1);
|
||||||
int64_t num_head = q_pe.size(1);
|
if (input_dim == 3) {
|
||||||
int64_t rotary_dim = q_pe.size(2);
|
// TODO: add support for head_dim != rotary_dim case when input_dim=3
|
||||||
int64_t HK = k_pe.size(2);
|
CHECK_EQ(query.size(-1), rotary_dim);
|
||||||
int64_t HR = t_emb_pos.size(1);
|
// TODO: add support for kv_head != 1
|
||||||
CHECK_EQ(HR, rotary_dim);
|
CHECK_EQ(key.size(1), 1);
|
||||||
CHECK_EQ(k_pe.size(0), seq_len);
|
}
|
||||||
CHECK_EQ(k_pe.size(1), 1);
|
|
||||||
CHECK_EQ(t_pos.size(0), seq_len);
|
|
||||||
CHECK_EQ(HK, rotary_dim);
|
|
||||||
|
|
||||||
at::Tensor q_pe_out = at::empty_like(q_pe);
|
int64_t num_tokens = positions.numel();
|
||||||
at::Tensor k_pe_out = at::empty_like(k_pe);
|
CHECK_EQ(key.size(0), num_tokens);
|
||||||
int64_t q_pe_stride_s = q_pe.stride(0);
|
CHECK_EQ(query.size(0), num_tokens);
|
||||||
int64_t q_pe_stride_n = q_pe.stride(1);
|
|
||||||
int64_t k_pe_stride_s = k_pe.stride(0);
|
|
||||||
int64_t out_stride_qs = q_pe_out.stride(0);
|
|
||||||
int64_t out_stride_qn = q_pe_out.stride(1);
|
|
||||||
int64_t out_stride_ks = k_pe_out.stride(0);
|
|
||||||
|
|
||||||
const auto input_dtype = q_pe.scalar_type();
|
TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
|
||||||
TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
|
TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
|
||||||
TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
|
TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");
|
||||||
TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
|
|
||||||
|
|
||||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
|
int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
|
||||||
rope_kernel_impl<scalar_t>(
|
int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
|
||||||
q_pe_out.data_ptr<scalar_t>(),
|
int64_t key_stride_s = key.stride(0);
|
||||||
k_pe_out.data_ptr<scalar_t>(),
|
int64_t query_stride_s = query.stride(0);
|
||||||
t_pos.data_ptr<int64_t>(),
|
|
||||||
q_pe.data_ptr<scalar_t>(),
|
// input stride of num head dim is meaningful only when input dim = 3
|
||||||
k_pe.data_ptr<scalar_t>(),
|
int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;
|
||||||
t_emb_pos.data_ptr<scalar_t>(),
|
at::Tensor query_out = at::empty_like(query);
|
||||||
seq_len,
|
at::Tensor key_out = at::empty_like(key);
|
||||||
num_head,
|
int64_t query_out_stride_s = query_out.stride(0);
|
||||||
rotary_dim,
|
int64_t key_out_stride_s = key_out.stride(0);
|
||||||
HR,
|
// output stride of num head dim is meaningful only when input dim = 3
|
||||||
q_pe_stride_s,
|
int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;
|
||||||
out_stride_qs,
|
|
||||||
out_stride_ks,
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
|
||||||
HK,
|
if (input_dim == 2) {
|
||||||
k_pe_stride_s,
|
if (is_neox) {
|
||||||
q_pe_stride_n,
|
rotary_embedding_neox_2D_kernel_impl<scalar_t>(
|
||||||
out_stride_qn);
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
rotary_dim,
|
||||||
|
query_stride_s,
|
||||||
|
key_stride_s,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
num_tokens);
|
||||||
|
} else {
|
||||||
|
rotary_embedding_2D_kernel_impl<scalar_t>(
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
rotary_dim,
|
||||||
|
query_stride_s,
|
||||||
|
key_stride_s,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
num_tokens);
|
||||||
|
}
|
||||||
|
query_out = query;
|
||||||
|
key_out = key;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(
|
||||||
|
is_neox == false, " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
|
||||||
|
// TODO: add neox style support for rope impl with 3D inputs
|
||||||
|
rotary_embedding_3D_kernel_impl<scalar_t>(
|
||||||
|
query_out.data_ptr<scalar_t>(),
|
||||||
|
key_out.data_ptr<scalar_t>(),
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
num_tokens,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
query_stride_s,
|
||||||
|
query_out_stride_s,
|
||||||
|
key_out_stride_s,
|
||||||
|
key_stride_s,
|
||||||
|
query_stride_h,
|
||||||
|
query_out_stride_h);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
return std::make_tuple(q_pe_out, k_pe_out);
|
return std::make_tuple(query_out, key_out);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int NUM_EXPERTS>
|
||||||
|
void topk_sigmoid_kernel_impl(
|
||||||
|
float* __restrict__ topk_weights,
|
||||||
|
int32_t* __restrict__ topk_ids,
|
||||||
|
const scalar_t* __restrict__ gating_output,
|
||||||
|
int64_t num_tokens,
|
||||||
|
int64_t topk,
|
||||||
|
bool renormalize) {
|
||||||
|
using Vec = at::vec::Vectorized<float>;
|
||||||
|
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||||
|
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
alignas(64) float scores[NUM_EXPERTS];
|
||||||
|
using elem_t = std::pair<float, int32_t>;
|
||||||
|
std::vector<elem_t> queue(num_experts_per_group);
|
||||||
|
|
||||||
|
for (int64_t i = begin; i < end; ++i) {
|
||||||
|
at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
|
||||||
|
|
||||||
|
float gmax = at::vec::reduce_all<float>(
|
||||||
|
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
|
||||||
|
|
||||||
|
// find position of first max,
|
||||||
|
// note that we may have multiple max values.
|
||||||
|
int first_max_idx = -1;
|
||||||
|
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||||
|
if (scores[e] == gmax) {
|
||||||
|
first_max_idx = e;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scalar sigmoid
|
||||||
|
topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
|
||||||
|
topk_ids[i] = first_max_idx;
|
||||||
|
|
||||||
|
if (renormalize) {
|
||||||
|
float sum = 0.f;
|
||||||
|
for (int64_t j = 0; j < topk; ++j) {
|
||||||
|
sum += topk_weights[i * topk + j];
|
||||||
|
}
|
||||||
|
float scale = 1.f / sum;
|
||||||
|
for (int64_t j = 0; j < topk; ++j) {
|
||||||
|
topk_weights[i * topk + j] *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int NUM_EXPERTS>
|
||||||
|
void topk_softmax_kernel_impl(
|
||||||
|
float* __restrict__ topk_weights,
|
||||||
|
int32_t* __restrict__ topk_ids,
|
||||||
|
const scalar_t* __restrict__ gating_output,
|
||||||
|
int64_t num_tokens,
|
||||||
|
int64_t topk,
|
||||||
|
bool renormalize) {
|
||||||
|
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||||
|
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
alignas(64) float scores[NUM_EXPERTS];
|
||||||
|
using elem_t = std::pair<float, int32_t>;
|
||||||
|
std::vector<elem_t> queue(num_experts_per_group);
|
||||||
|
|
||||||
|
for (int64_t i = begin; i < end; ++i) {
|
||||||
|
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||||
|
|
||||||
|
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||||
|
queue[e] = {scores[e], e};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::partial_sort(
|
||||||
|
queue.begin(),
|
||||||
|
queue.begin() + num_experts_per_group,
|
||||||
|
queue.end(),
|
||||||
|
[](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
|
||||||
|
|
||||||
|
for (int64_t j = 0; j < topk; ++j) {
|
||||||
|
topk_weights[i * topk + j] = queue[j].first;
|
||||||
|
topk_ids[i * topk + j] = queue[j].second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (renormalize) {
|
||||||
|
float sum = 0.f;
|
||||||
|
for (int64_t j = 0; j < topk; ++j) {
|
||||||
|
sum += topk_weights[i * topk + j];
|
||||||
|
}
|
||||||
|
float scale = 1.f / sum;
|
||||||
|
for (int64_t j = 0; j < topk; ++j) {
|
||||||
|
topk_weights[i * topk + j] *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
template <typename scalar_t, int SIZE>
|
template <typename scalar_t, int SIZE>
|
||||||
inline void
|
inline void
|
||||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
||||||
@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
|
|||||||
topk_group, \
|
topk_group, \
|
||||||
renormalize);
|
renormalize);
|
||||||
|
|
||||||
|
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
|
||||||
|
topk_sigmoid_kernel_impl<scalar_t, NE>( \
|
||||||
|
topk_weights.data_ptr<float>(), \
|
||||||
|
topk_ids.data_ptr<int32_t>(), \
|
||||||
|
gating_output.data_ptr<scalar_t>(), \
|
||||||
|
num_tokens, \
|
||||||
|
topk, \
|
||||||
|
renormalize);
|
||||||
|
|
||||||
|
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
|
||||||
|
topk_softmax_kernel_impl<scalar_t, NE>( \
|
||||||
|
topk_weights.data_ptr<float>(), \
|
||||||
|
topk_ids.data_ptr<int32_t>(), \
|
||||||
|
gating_output.data_ptr<scalar_t>(), \
|
||||||
|
num_tokens, \
|
||||||
|
topk, \
|
||||||
|
renormalize);
|
||||||
|
|
||||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||||
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
||||||
topk_weights.data_ptr<float>(), \
|
topk_weights.data_ptr<float>(), \
|
||||||
@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
|
|||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
|
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||||
|
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||||
|
CHECK_INPUT(gating_output);
|
||||||
|
|
||||||
|
const auto st = hidden_states.scalar_type();
|
||||||
|
CHECK_EQ(gating_output.scalar_type(), st);
|
||||||
|
|
||||||
|
int64_t num_tokens = hidden_states.size(0);
|
||||||
|
int64_t num_experts = gating_output.size(1);
|
||||||
|
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||||
|
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
|
||||||
|
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||||
|
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||||
|
|
||||||
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
|
||||||
|
switch (num_experts) {
|
||||||
|
case 1:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(2);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(4);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(8);
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(16);
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(32);
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(64);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(128);
|
||||||
|
break;
|
||||||
|
case 160:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(160);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_TOPK_SIGMOID_KERNEL(256);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return std::make_tuple(topk_weights, topk_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
|
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||||
|
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||||
|
CHECK_INPUT(gating_output);
|
||||||
|
|
||||||
|
const auto st = hidden_states.scalar_type();
|
||||||
|
CHECK_EQ(gating_output.scalar_type(), st);
|
||||||
|
|
||||||
|
int64_t num_tokens = hidden_states.size(0);
|
||||||
|
int64_t num_experts = gating_output.size(1);
|
||||||
|
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||||
|
|
||||||
|
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||||
|
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||||
|
|
||||||
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
|
||||||
|
switch (num_experts) {
|
||||||
|
case 1:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
|
||||||
|
break;
|
||||||
|
case 160:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return std::make_tuple(topk_weights, topk_ids);
|
||||||
|
}
|
||||||
|
|
||||||
// grouped topk for DeepSeek V2
|
// grouped topk for DeepSeek V2
|
||||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||||
at::Tensor& hidden_states,
|
at::Tensor& hidden_states,
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ limitations under the License.
|
|||||||
// silu_and_mul
|
// silu_and_mul
|
||||||
at::Tensor silu_and_mul_cpu(at::Tensor& input);
|
at::Tensor silu_and_mul_cpu(at::Tensor& input);
|
||||||
|
|
||||||
|
// l2norm
|
||||||
|
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
|
||||||
|
|
||||||
// rmsnorm
|
// rmsnorm
|
||||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
|
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
|
||||||
|
|
||||||
@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
|
|||||||
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
|
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
|
||||||
|
|
||||||
// topk
|
// topk
|
||||||
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
|
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
|
||||||
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
|
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||||
at::Tensor& hidden_states,
|
at::Tensor& hidden_states,
|
||||||
at::Tensor& gating_output,
|
at::Tensor& gating_output,
|
||||||
@@ -185,8 +193,13 @@ void shm_allreduce(
|
|||||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
|
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
|
||||||
|
|
||||||
// rope
|
// rope
|
||||||
std::tuple<at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
|
||||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
|
at::Tensor& positions,
|
||||||
|
at::Tensor& query,
|
||||||
|
at::Tensor& key,
|
||||||
|
int64_t head_size,
|
||||||
|
at::Tensor& cos_sin_cache,
|
||||||
|
bool is_neox);
|
||||||
|
|
||||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||||
// activation
|
// activation
|
||||||
@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
// norm
|
// norm
|
||||||
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
|
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
|
||||||
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
|
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
|
||||||
|
m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
|
||||||
|
m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
|
||||||
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
|
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
|
||||||
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
|
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
|
||||||
|
|
||||||
// topk
|
// topk
|
||||||
|
m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
|
||||||
|
m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
|
||||||
|
m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
|
||||||
|
m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
|
||||||
m.def(
|
m.def(
|
||||||
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
|
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
|
||||||
"int topk_group) -> (Tensor, Tensor)");
|
"int topk_group) -> (Tensor, Tensor)");
|
||||||
@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
||||||
|
|
||||||
// rope
|
// rope
|
||||||
m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
|
m.def(
|
||||||
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
|
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
|
||||||
|
"bool is_neox) -> (Tensor, Tensor)");
|
||||||
|
m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(common_ops)
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
@@ -63,10 +63,24 @@ class TestNorm(CustomTestCase):
|
|||||||
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
|
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
|
||||||
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
|
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
|
def _l2norm_test(self, m, n, dtype):
|
||||||
|
|
||||||
|
x = torch.randn([m, n], dtype=dtype)
|
||||||
|
hidden_size = x.size(-1)
|
||||||
|
fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
|
||||||
|
variance_epsilon = 1e-6
|
||||||
|
|
||||||
|
out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
|
||||||
|
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
|
||||||
|
|
||||||
|
atol = rtol = precision[ref_out.dtype]
|
||||||
|
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
def test_norm(self):
|
def test_norm(self):
|
||||||
for params in itertools.product(self.M, self.N, self.dtype):
|
for params in itertools.product(self.M, self.N, self.dtype):
|
||||||
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||||
self._norm_test(*params)
|
self._norm_test(*params)
|
||||||
|
self._l2norm_test(*params)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ import sgl_kernel
|
|||||||
import torch
|
import torch
|
||||||
from utils import precision
|
from utils import precision
|
||||||
|
|
||||||
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import (
|
||||||
|
DeepseekScalingRotaryEmbedding,
|
||||||
|
RotaryEmbedding,
|
||||||
|
)
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# fused rope kernel
|
# fused rope kernel
|
||||||
q_pe_clone, k_pe_clone = (
|
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||||
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
|
positions,
|
||||||
positions, q_pe_clone, k_pe_clone, cos_sin_cache
|
q_pe_clone,
|
||||||
)
|
k_pe_clone,
|
||||||
|
rope.head_size,
|
||||||
|
cos_sin_cache,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
atol = rtol = precision[q_pe.dtype]
|
atol = rtol = precision[q_pe.dtype]
|
||||||
@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
|
|||||||
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
|
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
|
||||||
torch.testing.assert_close(k_pe, k_pe_clone)
|
torch.testing.assert_close(k_pe, k_pe_clone)
|
||||||
|
|
||||||
|
def test_origin_rope(self):
|
||||||
|
def single_test(
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: int,
|
||||||
|
is_neox_style: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
):
|
||||||
|
torch.manual_seed(100)
|
||||||
|
rope_ref = RotaryEmbedding(
|
||||||
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
).to(device)
|
||||||
|
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||||
|
query = torch.randn(
|
||||||
|
batch_size * seq_len,
|
||||||
|
num_q_heads * head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
key = torch.randn(
|
||||||
|
batch_size * seq_len,
|
||||||
|
num_kv_heads * head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_ref, key_ref = query.clone(), key.clone()
|
||||||
|
query_cpu, key_cpu = query.clone(), key.clone()
|
||||||
|
|
||||||
|
query_ref_out, key_ref_out = rope_ref.forward_native(
|
||||||
|
pos_ids, query_ref, key_ref
|
||||||
|
)
|
||||||
|
query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||||
|
pos_ids,
|
||||||
|
query_cpu,
|
||||||
|
key_cpu,
|
||||||
|
rope_ref.head_size,
|
||||||
|
rope_ref.cos_sin_cache.to(query.dtype),
|
||||||
|
rope_ref.is_neox_style,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
test_config = [
|
||||||
|
(64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
|
||||||
|
(256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||||
|
(512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||||
|
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||||
|
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
|
||||||
|
(512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||||
|
]
|
||||||
|
|
||||||
|
for (
|
||||||
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
num_q_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
) in test_config:
|
||||||
|
single_test(
|
||||||
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
num_q_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from utils import precision
|
|||||||
from sglang.srt.layers.moe.topk import (
|
from sglang.srt.layers.moe.topk import (
|
||||||
biased_grouped_topk_impl as native_biased_grouped_topk,
|
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
|
||||||
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
||||||
|
from sglang.srt.models.llama4 import Llama4MoE
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
|||||||
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopK(CustomTestCase):
|
||||||
|
def _run_single_test(self, M, E, topk, renormalize, dtype):
|
||||||
|
torch.manual_seed(1998)
|
||||||
|
|
||||||
|
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||||
|
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||||
|
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||||
|
|
||||||
|
ref_topk_weights, ref_topk_ids = native_fused_topk(
|
||||||
|
hidden_states.float(),
|
||||||
|
gating_output.float(),
|
||||||
|
topk,
|
||||||
|
renormalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused version
|
||||||
|
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
|
||||||
|
hidden_states, gating_output, topk, renormalize
|
||||||
|
)
|
||||||
|
|
||||||
|
res = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
ref = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||||
|
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||||
|
torch.testing.assert_close(res, ref)
|
||||||
|
|
||||||
|
def test_topk(self):
|
||||||
|
for renormalize in [True, False]:
|
||||||
|
self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
|
||||||
|
self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomTopK(CustomTestCase):
|
||||||
|
def _run_single_test(
|
||||||
|
self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
|
||||||
|
):
|
||||||
|
torch.manual_seed(16)
|
||||||
|
|
||||||
|
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||||
|
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||||
|
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||||
|
|
||||||
|
ref_topk_weights, ref_topk_ids = native_custom_f(
|
||||||
|
hidden_states.float(),
|
||||||
|
gating_output.float(),
|
||||||
|
topk,
|
||||||
|
renormalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused version
|
||||||
|
topk_weights, topk_ids = fused_custom_f(
|
||||||
|
hidden_states, gating_output, topk, renormalize
|
||||||
|
)
|
||||||
|
|
||||||
|
res = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
ref = torch.zeros(M, E, dtype=torch.float)
|
||||||
|
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||||
|
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||||
|
torch.testing.assert_close(res, ref)
|
||||||
|
|
||||||
|
def test_custom_topk(self):
|
||||||
|
test_custom_functions = [
|
||||||
|
(Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
|
||||||
|
]
|
||||||
|
for native_custom_f, fused_custom_f in test_custom_functions:
|
||||||
|
self._run_single_test(
|
||||||
|
123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||||
|
)
|
||||||
|
self._run_single_test(
|
||||||
|
123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||||
|
)
|
||||||
|
self._run_single_test(
|
||||||
|
123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user