fix sgl-kernel codestyle (#3563)
This commit is contained in:
@@ -33,11 +33,11 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
||||
const int batch_size, const int num_heads, const int qk_dim,
|
||||
const int v_dim) {
|
||||
extern __shared__ char smem[];
|
||||
T* q_shared = reinterpret_cast<T*>(smem);
|
||||
T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
||||
T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
||||
float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
||||
T* output_shared =
|
||||
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
|
||||
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
||||
T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
||||
float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
||||
T* __restrict__ output_shared =
|
||||
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
|
||||
|
||||
const int32_t tid = threadIdx.x;
|
||||
@@ -51,6 +51,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
||||
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
|
||||
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
|
||||
|
||||
// Load q, k, v into shared memory
|
||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||
q_shared[d] = q[qk_offset + d];
|
||||
k_shared[d] = k[qk_offset + d];
|
||||
@@ -63,33 +64,36 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
||||
|
||||
const float ratio = expf(-1.0f * slope[h]);
|
||||
|
||||
// Compute new_kv
|
||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||
T k_val = k_shared[d];
|
||||
const T k_val = k_shared[d];
|
||||
for (int e = 0; e < v_dim; ++e) {
|
||||
int past_kv_idx = kv_offset + d * v_dim + e;
|
||||
T v_val = v_shared[e];
|
||||
float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
||||
int shared_idx = d * (v_dim + 1) + e;
|
||||
const int past_kv_idx = kv_offset + d * v_dim + e;
|
||||
const T v_val = v_shared[e];
|
||||
const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
new_kv_shared[shared_idx] = new_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Store new_kv to global memory
|
||||
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
|
||||
int d = idx / v_dim;
|
||||
int e = idx % v_dim;
|
||||
int shared_idx = d * (v_dim + 1) + e;
|
||||
int global_idx = kv_offset + idx;
|
||||
const int d = idx / v_dim;
|
||||
const int e = idx % v_dim;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
const int global_idx = kv_offset + idx;
|
||||
new_kv[global_idx] = new_kv_shared[shared_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute output
|
||||
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||
float sum = 0.0f;
|
||||
for (int d = 0; d < qk_dim; ++d) {
|
||||
int shared_idx = d * (v_dim + 1) + e;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
sum += q_shared[d] * new_kv_shared[shared_idx];
|
||||
}
|
||||
output_shared[e] = static_cast<T>(sum);
|
||||
@@ -97,6 +101,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Store output to global memory
|
||||
if (tid == 0) {
|
||||
for (int e = 0; e < v_dim; ++e) {
|
||||
output[v_offset + e] = output_shared[e];
|
||||
|
||||
@@ -25,8 +25,9 @@ limitations under the License.
|
||||
#define WARP_SIZE 32
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||
int32_t* cumsum_buffer, size_t numel) {
|
||||
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ cumsum_buffer, size_t numel) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = blockDim.x * gridDim.x;
|
||||
|
||||
@@ -38,9 +39,10 @@ __global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t*
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel, int32_t* cumsum) {
|
||||
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
|
||||
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
@@ -106,7 +108,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
auto sort_kernel = moe_token_sort_kernel<scalar_t>;
|
||||
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
|
||||
|
||||
@@ -7,13 +7,11 @@
|
||||
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
|
||||
__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) {
|
||||
if (tid < 8) {
|
||||
smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
|
||||
if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
|
||||
if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
|
||||
if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
|
||||
}
|
||||
__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) {
|
||||
smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
|
||||
if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
|
||||
if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
|
||||
if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
|
||||
return smem[0];
|
||||
}
|
||||
|
||||
@@ -53,7 +51,7 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
|
||||
|
||||
// Perform reduction within each group
|
||||
if (local_tid < 8) {
|
||||
WarpReduce(&s_absmax[local_group_id][0], local_tid);
|
||||
GroupReduce(&s_absmax[local_group_id][0], local_tid);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user