Add CPU optimized kernels for topk and rope fusions (#6456)
This commit is contained in:
@@ -4,6 +4,67 @@
|
||||
namespace {
|
||||
|
||||
// NB: avoid using `at::vec::map<>` on bfloat16 or half
|
||||
// Llama4TextL2Norm
|
||||
template <typename scalar_t>
|
||||
void l2norm_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t batch_size,
|
||||
int64_t hidden_size,
|
||||
float eps = 1e-5) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int kVecSize = bVec::size();
|
||||
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
|
||||
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
|
||||
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
float sum_val = float(0);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec += x_fvec0 * x_fvec0;
|
||||
sum_fvec += x_fvec1 * x_fvec1;
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < hidden_size; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
sum_val += x_val * x_val;
|
||||
}
|
||||
|
||||
sum_val += vec_reduce_sum(sum_fvec);
|
||||
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
||||
const fVec scale_fvec = fVec(rsqrt_var);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
x_fvec0 = x_fvec0 * scale_fvec;
|
||||
x_fvec1 = x_fvec1 * scale_fvec;
|
||||
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
out_bvec.store(out_ptr + d);
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < hidden_size; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
template <typename scalar_t>
|
||||
void rmsnorm_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
@@ -160,6 +221,22 @@ void fused_add_rmsnorm_kernel_impl(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
|
||||
RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));
|
||||
|
||||
CHECK_INPUT(input);
|
||||
CHECK_DIM(2, input);
|
||||
int64_t batch_size = input.size(0);
|
||||
int64_t hidden_size = input.size(1);
|
||||
at::Tensor output = at::empty_like(input);
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
|
||||
l2norm_kernel_impl<scalar_t>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), batch_size, hidden_size, eps);
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
// weight: {hidden_size}
|
||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||
|
||||
@@ -4,126 +4,343 @@
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
void rope_kernel_impl(
|
||||
scalar_t* __restrict__ q_pe_out,
|
||||
scalar_t* __restrict__ k_pe_out,
|
||||
int64_t* __restrict__ t_pos,
|
||||
scalar_t* __restrict__ q_pe,
|
||||
scalar_t* __restrict__ k_pe,
|
||||
scalar_t* __restrict__ t_emb_pos,
|
||||
int64_t seq_len,
|
||||
int64_t num_head,
|
||||
void rotary_embedding_3D_kernel_impl(
|
||||
scalar_t* __restrict__ query_out,
|
||||
scalar_t* __restrict__ key_out,
|
||||
int64_t* __restrict__ positions,
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ cos_sin_cache,
|
||||
int64_t num_tokens,
|
||||
int64_t num_heads,
|
||||
int64_t num_kv_heads,
|
||||
int64_t head_size,
|
||||
int64_t rotary_dim,
|
||||
int64_t HR,
|
||||
int64_t q_pe_stride_s,
|
||||
int64_t out_stride_qs,
|
||||
int64_t out_stride_ks,
|
||||
int64_t HK,
|
||||
int64_t k_pe_stride_s,
|
||||
int64_t q_pe_stride_n,
|
||||
int64_t out_stride_qn) {
|
||||
int64_t query_stride_s,
|
||||
int64_t query_out_stride_s,
|
||||
int64_t key_out_stride_s,
|
||||
int64_t key_stride_s,
|
||||
int64_t query_stride_h,
|
||||
int64_t query_out_stride_h) {
|
||||
int64_t HR = rotary_dim;
|
||||
int64_t HK = rotary_dim;
|
||||
int64_t COFF = HR / 2;
|
||||
at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t seq{0}, head_id{0};
|
||||
data_index_init(begin, seq, seq_len, head_id, num_head);
|
||||
data_index_init(begin, seq, num_tokens, head_id, num_heads);
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
|
||||
int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
|
||||
int64_t out_offset_k = seq * out_stride_ks;
|
||||
int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
|
||||
int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
|
||||
int64_t out_offset_k = seq * key_out_stride_s;
|
||||
int64_t p = 0;
|
||||
scalar_t* sin_start = nullptr;
|
||||
scalar_t* cos_start = nullptr;
|
||||
// step 0) get the rotary position embedding for the current position
|
||||
p = t_pos[seq];
|
||||
sin_start = t_emb_pos + p * HR + COFF;
|
||||
cos_start = t_emb_pos + p * HR;
|
||||
p = positions[seq];
|
||||
sin_start = cos_sin_cache + p * HR + COFF;
|
||||
cos_start = cos_sin_cache + p * HR;
|
||||
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
|
||||
// head of query/key
|
||||
for (int64_t h = 0; h < rotary_dim; h += 2) {
|
||||
scalar_t cos = cos_start[h >> 1];
|
||||
scalar_t sin = sin_start[h >> 1];
|
||||
scalar_t in1 = q_pe[in_offset_q + h];
|
||||
scalar_t in2 = q_pe[in_offset_q + h + 1];
|
||||
scalar_t in1 = query[in_offset_q + h];
|
||||
scalar_t in2 = query[in_offset_q + h + 1];
|
||||
scalar_t out1 = in1 * cos - in2 * sin;
|
||||
scalar_t out2 = in2 * cos + in1 * sin;
|
||||
q_pe_out[out_offset_q + h] = out1;
|
||||
q_pe_out[out_offset_q + h + 1] = out2;
|
||||
query_out[out_offset_q + h] = out1;
|
||||
query_out[out_offset_q + h + 1] = out2;
|
||||
}
|
||||
for (int64_t h = 0; h < HK; h += 2) {
|
||||
scalar_t cos = cos_start[h >> 1];
|
||||
scalar_t sin = sin_start[h >> 1];
|
||||
int64_t k_pe_offset = seq * k_pe_stride_s;
|
||||
scalar_t in1_k = k_pe[k_pe_offset + h];
|
||||
scalar_t in2_k = k_pe[k_pe_offset + h + 1];
|
||||
int64_t k_pe_offset = seq * key_stride_s;
|
||||
scalar_t in1_k = key[k_pe_offset + h];
|
||||
scalar_t in2_k = key[k_pe_offset + h + 1];
|
||||
scalar_t out1_k = in1_k * cos - in2_k * sin;
|
||||
scalar_t out2_k = in2_k * cos + in1_k * sin;
|
||||
k_pe_out[out_offset_k + h] = out1_k;
|
||||
k_pe_out[out_offset_k + h + 1] = out2_k;
|
||||
key_out[out_offset_k + h] = out1_k;
|
||||
key_out[out_offset_k + h + 1] = out2_k;
|
||||
}
|
||||
// move to the next index
|
||||
data_index_step(seq, seq_len, head_id, num_head);
|
||||
data_index_step(seq, num_tokens, head_id, num_heads);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_neox_2D_kernel_impl(
|
||||
int64_t* __restrict__ positions,
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ cos_sin_cache,
|
||||
int64_t rotary_dim,
|
||||
int64_t query_stride_s,
|
||||
int64_t key_stride_s,
|
||||
int64_t num_heads,
|
||||
int64_t num_kv_heads,
|
||||
int64_t head_size,
|
||||
int64_t num_tokens) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int64_t bVecSize = bVec::size();
|
||||
|
||||
int64_t embed_dim = rotary_dim / 2;
|
||||
bool flag = (embed_dim % bVecSize == 0);
|
||||
int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;
|
||||
|
||||
auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
|
||||
int64_t j = 0;
|
||||
for (; j < loop_upper; j += bVecSize) {
|
||||
int64_t rot_offset = j;
|
||||
int64_t x_index = rot_offset;
|
||||
int64_t y_index = embed_dim + rot_offset;
|
||||
|
||||
int64_t out_x = token_head + x_index;
|
||||
int64_t out_y = token_head + y_index;
|
||||
|
||||
bVec _cos = bVec::loadu(cache_ptr + x_index);
|
||||
bVec _sin = bVec::loadu(cache_ptr + y_index);
|
||||
|
||||
bVec _q_x = bVec::loadu(qk + out_x);
|
||||
bVec _q_y = bVec::loadu(qk + out_y);
|
||||
fVec _cos_0, _cos_1;
|
||||
std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
|
||||
fVec _sin_0, _sin_1;
|
||||
std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
|
||||
fVec _q_x_0, _q_x_1;
|
||||
std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
|
||||
fVec _q_y_0, _q_y_1;
|
||||
std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);
|
||||
|
||||
auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
|
||||
auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
|
||||
auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
|
||||
out1.store(qk + out_x);
|
||||
|
||||
auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
|
||||
auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
|
||||
auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
|
||||
out2.store(qk + out_y);
|
||||
}
|
||||
if (!flag) {
|
||||
for (; j < embed_dim; ++j) {
|
||||
int64_t x_index = j;
|
||||
int64_t y_index = embed_dim + j;
|
||||
|
||||
int64_t out_x = token_head + x_index;
|
||||
int64_t out_y = token_head + y_index;
|
||||
|
||||
float _cos = cache_ptr[x_index];
|
||||
float _sin = cache_ptr[y_index];
|
||||
|
||||
float _q_x = qk[out_x];
|
||||
float _q_y = qk[out_y];
|
||||
|
||||
qk[out_x] = _q_x * _cos - _q_y * _sin;
|
||||
qk[out_y] = _q_y * _cos + _q_x * _sin;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
int64_t pos = positions[token_idx];
|
||||
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||
|
||||
for (int64_t i = 0; i < num_heads; ++i) {
|
||||
int64_t head_idx = i;
|
||||
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
|
||||
compute_loop(token_head, cache_ptr, query);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < num_kv_heads; ++i) {
|
||||
int64_t head_idx = i;
|
||||
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
|
||||
compute_loop(token_head, cache_ptr, key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_2D_kernel_impl(
|
||||
int64_t* __restrict__ positions,
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ cos_sin_cache,
|
||||
int64_t rotary_dim,
|
||||
int64_t query_stride_s,
|
||||
int64_t key_stride_s,
|
||||
int64_t num_heads,
|
||||
int64_t num_kv_heads,
|
||||
int64_t head_size,
|
||||
int64_t num_tokens) {
|
||||
int64_t embed_dim = rotary_dim / 2;
|
||||
|
||||
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t token_idx = {0}, i = {0};
|
||||
data_index_init(begin, token_idx, num_tokens, i, num_heads);
|
||||
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
||||
int64_t pos = positions[token_idx];
|
||||
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||
scalar_t* cos_cache_ptr = cache_ptr;
|
||||
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||
int64_t head_idx = i;
|
||||
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
|
||||
scalar_t* head_query = token_head + query;
|
||||
for (int64_t j = 0; j < embed_dim; j += 1) {
|
||||
int64_t rot_offset = j;
|
||||
int64_t x_index = 2 * rot_offset;
|
||||
int64_t y_index = 2 * rot_offset + 1;
|
||||
|
||||
float cos = cos_cache_ptr[rot_offset];
|
||||
float sin = sin_cache_ptr[rot_offset];
|
||||
|
||||
float x = head_query[x_index];
|
||||
float y = head_query[y_index];
|
||||
|
||||
head_query[x_index] = x * cos - y * sin;
|
||||
head_query[y_index] = y * cos + x * sin;
|
||||
}
|
||||
data_index_step(token_idx, num_tokens, i, num_heads);
|
||||
}
|
||||
});
|
||||
|
||||
at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t token_idx{0}, i = {0};
|
||||
data_index_init(begin, token_idx, num_tokens, i, num_kv_heads);
|
||||
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
||||
int64_t pos = positions[token_idx];
|
||||
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||
scalar_t* cos_cache_ptr = cache_ptr;
|
||||
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||
int64_t head_idx = i;
|
||||
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
|
||||
scalar_t* head_key = key + token_head;
|
||||
for (int64_t j = 0; j < embed_dim; j += 1) {
|
||||
int64_t rot_offset = j;
|
||||
int64_t x_index = 2 * rot_offset;
|
||||
int64_t y_index = 2 * rot_offset + 1;
|
||||
|
||||
float cos = cos_cache_ptr[rot_offset];
|
||||
float sin = sin_cache_ptr[rot_offset];
|
||||
|
||||
float x = head_key[x_index];
|
||||
float y = head_key[y_index];
|
||||
|
||||
head_key[x_index] = x * cos - y * sin;
|
||||
head_key[y_index] = y * cos + x * sin;
|
||||
}
|
||||
data_index_step(token_idx, num_tokens, i, num_kv_heads);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
|
||||
CHECK_INPUT(t_pos);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
|
||||
CHECK_INPUT(t_emb_pos);
|
||||
CHECK_DIM(1, t_pos);
|
||||
CHECK_DIM(3, q_pe);
|
||||
CHECK_DIM(3, k_pe);
|
||||
CHECK_DIM(2, t_emb_pos);
|
||||
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
|
||||
at::Tensor& positions,
|
||||
at::Tensor& query,
|
||||
at::Tensor& key,
|
||||
int64_t head_size,
|
||||
at::Tensor& cos_sin_cache,
|
||||
bool is_neox) {
|
||||
RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));
|
||||
CHECK_DIM(1, positions);
|
||||
const auto input_dim = query.dim();
|
||||
const auto input_dtype = query.scalar_type();
|
||||
TORCH_CHECK(
|
||||
input_dim == 2 || input_dim == 3,
|
||||
" Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
|
||||
CHECK_DIM(2, cos_sin_cache);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
|
||||
|
||||
int64_t seq_len = q_pe.size(0);
|
||||
int64_t num_head = q_pe.size(1);
|
||||
int64_t rotary_dim = q_pe.size(2);
|
||||
int64_t HK = k_pe.size(2);
|
||||
int64_t HR = t_emb_pos.size(1);
|
||||
CHECK_EQ(HR, rotary_dim);
|
||||
CHECK_EQ(k_pe.size(0), seq_len);
|
||||
CHECK_EQ(k_pe.size(1), 1);
|
||||
CHECK_EQ(t_pos.size(0), seq_len);
|
||||
CHECK_EQ(HK, rotary_dim);
|
||||
int64_t rotary_dim = cos_sin_cache.size(1);
|
||||
if (input_dim == 3) {
|
||||
// TODO: add support for head_dim != rotary_dim case when input_dim=3
|
||||
CHECK_EQ(query.size(-1), rotary_dim);
|
||||
// TODO: add support for kv_head != 1
|
||||
CHECK_EQ(key.size(1), 1);
|
||||
}
|
||||
|
||||
at::Tensor q_pe_out = at::empty_like(q_pe);
|
||||
at::Tensor k_pe_out = at::empty_like(k_pe);
|
||||
int64_t q_pe_stride_s = q_pe.stride(0);
|
||||
int64_t q_pe_stride_n = q_pe.stride(1);
|
||||
int64_t k_pe_stride_s = k_pe.stride(0);
|
||||
int64_t out_stride_qs = q_pe_out.stride(0);
|
||||
int64_t out_stride_qn = q_pe_out.stride(1);
|
||||
int64_t out_stride_ks = k_pe_out.stride(0);
|
||||
int64_t num_tokens = positions.numel();
|
||||
CHECK_EQ(key.size(0), num_tokens);
|
||||
CHECK_EQ(query.size(0), num_tokens);
|
||||
|
||||
const auto input_dtype = q_pe.scalar_type();
|
||||
TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
|
||||
TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
|
||||
TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
|
||||
TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
|
||||
TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
|
||||
TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
|
||||
rope_kernel_impl<scalar_t>(
|
||||
q_pe_out.data_ptr<scalar_t>(),
|
||||
k_pe_out.data_ptr<scalar_t>(),
|
||||
t_pos.data_ptr<int64_t>(),
|
||||
q_pe.data_ptr<scalar_t>(),
|
||||
k_pe.data_ptr<scalar_t>(),
|
||||
t_emb_pos.data_ptr<scalar_t>(),
|
||||
seq_len,
|
||||
num_head,
|
||||
rotary_dim,
|
||||
HR,
|
||||
q_pe_stride_s,
|
||||
out_stride_qs,
|
||||
out_stride_ks,
|
||||
HK,
|
||||
k_pe_stride_s,
|
||||
q_pe_stride_n,
|
||||
out_stride_qn);
|
||||
int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
|
||||
int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
|
||||
int64_t key_stride_s = key.stride(0);
|
||||
int64_t query_stride_s = query.stride(0);
|
||||
|
||||
// input stride of num head dim is meaningful only when input dim = 3
|
||||
int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;
|
||||
at::Tensor query_out = at::empty_like(query);
|
||||
at::Tensor key_out = at::empty_like(key);
|
||||
int64_t query_out_stride_s = query_out.stride(0);
|
||||
int64_t key_out_stride_s = key_out.stride(0);
|
||||
// output stride of num head dim is meaningful only when input dim = 3
|
||||
int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
|
||||
if (input_dim == 2) {
|
||||
if (is_neox) {
|
||||
rotary_embedding_neox_2D_kernel_impl<scalar_t>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rotary_dim,
|
||||
query_stride_s,
|
||||
key_stride_s,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
rotary_embedding_2D_kernel_impl<scalar_t>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rotary_dim,
|
||||
query_stride_s,
|
||||
key_stride_s,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
num_tokens);
|
||||
}
|
||||
query_out = query;
|
||||
key_out = key;
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
is_neox == false, " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
|
||||
// TODO: add neox style support for rope impl with 3D inputs
|
||||
rotary_embedding_3D_kernel_impl<scalar_t>(
|
||||
query_out.data_ptr<scalar_t>(),
|
||||
key_out.data_ptr<scalar_t>(),
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
rotary_dim,
|
||||
query_stride_s,
|
||||
query_out_stride_s,
|
||||
key_out_stride_s,
|
||||
key_stride_s,
|
||||
query_stride_h,
|
||||
query_out_stride_h);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(q_pe_out, k_pe_out);
|
||||
return std::make_tuple(query_out, key_out);
|
||||
}
|
||||
|
||||
@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS>
|
||||
void topk_sigmoid_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
int64_t num_tokens,
|
||||
int64_t topk,
|
||||
bool renormalize) {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
|
||||
|
||||
float gmax = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
|
||||
|
||||
// find position of first max,
|
||||
// note that we may have multiple max values.
|
||||
int first_max_idx = -1;
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
if (scores[e] == gmax) {
|
||||
first_max_idx = e;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// scalar sigmoid
|
||||
topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
|
||||
topk_ids[i] = first_max_idx;
|
||||
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
sum += topk_weights[i * topk + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS>
|
||||
void topk_softmax_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
int64_t num_tokens,
|
||||
int64_t topk,
|
||||
bool renormalize) {
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
queue[e] = {scores[e], e};
|
||||
}
|
||||
|
||||
std::partial_sort(
|
||||
queue.begin(),
|
||||
queue.begin() + num_experts_per_group,
|
||||
queue.end(),
|
||||
[](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
|
||||
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] = queue[j].first;
|
||||
topk_ids[i * topk + j] = queue[j].second;
|
||||
}
|
||||
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
sum += topk_weights[i * topk + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
inline void
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
||||
@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
|
||||
topk_group, \
|
||||
renormalize);
|
||||
|
||||
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
|
||||
topk_sigmoid_kernel_impl<scalar_t, NE>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
num_tokens, \
|
||||
topk, \
|
||||
renormalize);
|
||||
|
||||
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
|
||||
topk_softmax_kernel_impl<scalar_t, NE>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
num_tokens, \
|
||||
topk, \
|
||||
renormalize);
|
||||
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
// grouped topk for DeepSeek V2
|
||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
|
||||
@@ -23,6 +23,9 @@ limitations under the License.
|
||||
// silu_and_mul
|
||||
at::Tensor silu_and_mul_cpu(at::Tensor& input);
|
||||
|
||||
// l2norm
|
||||
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
|
||||
|
||||
// rmsnorm
|
||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
|
||||
|
||||
@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
|
||||
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
|
||||
|
||||
// topk
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
@@ -185,8 +193,13 @@ void shm_allreduce(
|
||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
|
||||
|
||||
// rope
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
|
||||
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
|
||||
at::Tensor& positions,
|
||||
at::Tensor& query,
|
||||
at::Tensor& key,
|
||||
int64_t head_size,
|
||||
at::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// activation
|
||||
@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// norm
|
||||
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
|
||||
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
|
||||
m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
|
||||
m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
|
||||
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
|
||||
|
||||
// topk
|
||||
m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
|
||||
m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
|
||||
m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
|
||||
m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
|
||||
m.def(
|
||||
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
|
||||
"int topk_group) -> (Tensor, Tensor)");
|
||||
@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
||||
|
||||
// rope
|
||||
m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
|
||||
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
|
||||
m.def(
|
||||
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
|
||||
"bool is_neox) -> (Tensor, Tensor)");
|
||||
m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
@@ -63,10 +63,24 @@ class TestNorm(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
|
||||
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
|
||||
|
||||
def _l2norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
hidden_size = x.size(-1)
|
||||
fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
|
||||
out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
|
||||
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
def test_norm(self):
|
||||
for params in itertools.product(self.M, self.N, self.dtype):
|
||||
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||
self._norm_test(*params)
|
||||
self._l2norm_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,7 +4,10 @@ import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
|
||||
from sglang.srt.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
|
||||
)
|
||||
|
||||
# fused rope kernel
|
||||
q_pe_clone, k_pe_clone = (
|
||||
torch.ops.sgl_kernel.rotary_position_embedding_cpu(
|
||||
positions, q_pe_clone, k_pe_clone, cos_sin_cache
|
||||
)
|
||||
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
positions,
|
||||
q_pe_clone,
|
||||
k_pe_clone,
|
||||
rope.head_size,
|
||||
cos_sin_cache,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[q_pe.dtype]
|
||||
@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
|
||||
torch.testing.assert_close(k_pe, k_pe_clone)
|
||||
|
||||
def test_origin_rope(self):
|
||||
def single_test(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
torch.manual_seed(100)
|
||||
rope_ref = RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
).to(device)
|
||||
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len,
|
||||
num_q_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
query_ref, key_ref = query.clone(), key.clone()
|
||||
query_cpu, key_cpu = query.clone(), key.clone()
|
||||
|
||||
query_ref_out, key_ref_out = rope_ref.forward_native(
|
||||
pos_ids, query_ref, key_ref
|
||||
)
|
||||
query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
pos_ids,
|
||||
query_cpu,
|
||||
key_cpu,
|
||||
rope_ref.head_size,
|
||||
rope_ref.cos_sin_cache.to(query.dtype),
|
||||
rope_ref.is_neox_style,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
test_config = [
|
||||
(64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
|
||||
(256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||
(512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
|
||||
(512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||
]
|
||||
|
||||
for (
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
) in test_config:
|
||||
single_test(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,7 +8,9 @@ from utils import precision
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
|
||||
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
|
||||
from sglang.srt.models.llama4 import Llama4MoE
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
|
||||
def _run_single_test(self, M, E, topk, renormalize, dtype):
|
||||
torch.manual_seed(1998)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_fused_topk(
|
||||
hidden_states.float(),
|
||||
gating_output.float(),
|
||||
topk,
|
||||
renormalize,
|
||||
)
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
ref = torch.zeros(M, E, dtype=torch.float)
|
||||
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||
torch.testing.assert_close(res, ref)
|
||||
|
||||
def test_topk(self):
|
||||
for renormalize in [True, False]:
|
||||
self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
class TestCustomTopK(CustomTestCase):
|
||||
def _run_single_test(
|
||||
self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
|
||||
):
|
||||
torch.manual_seed(16)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_custom_f(
|
||||
hidden_states.float(),
|
||||
gating_output.float(),
|
||||
topk,
|
||||
renormalize,
|
||||
)
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = fused_custom_f(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
ref = torch.zeros(M, E, dtype=torch.float)
|
||||
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||
torch.testing.assert_close(res, ref)
|
||||
|
||||
def test_custom_topk(self):
|
||||
test_custom_functions = [
|
||||
(Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
|
||||
]
|
||||
for native_custom_f, fused_custom_f in test_custom_functions:
|
||||
self._run_single_test(
|
||||
123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||
)
|
||||
self._run_single_test(
|
||||
123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||
)
|
||||
self._run_single_test(
|
||||
123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user