From 3efbdf68b91e29245e41702b9cbe60aca7cd6351 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 14 Feb 2025 18:05:52 +0800 Subject: [PATCH] fix sgl-kernel codestyle (#3563) --- .../csrc/lightning_attention_decode_kernel.cu | 35 +++++++++++-------- .../src/sgl-kernel/csrc/moe_align_kernel.cu | 14 ++++---- .../csrc/per_token_group_quant_fp8.cu | 14 ++++---- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu index e9fc1c0ec..02c50498e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -33,11 +33,11 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q, const int batch_size, const int num_heads, const int qk_dim, const int v_dim) { extern __shared__ char smem[]; - T* q_shared = reinterpret_cast<T*>(smem); - T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T)); - T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T)); - float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T)); - T* output_shared = + T* __restrict__ q_shared = reinterpret_cast<T*>(smem); + T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T)); + T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T)); + float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T)); + T* __restrict__ output_shared = reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float)); const int32_t tid = threadIdx.x; @@ -51,6 +51,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q, const int32_t v_offset = b * num_heads * v_dim + h * v_dim; const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim; + // Load q, k, v into shared 
memory for (int d = tid; d < qk_dim; d += blockDim.x) { q_shared[d] = q[qk_offset + d]; k_shared[d] = k[qk_offset + d]; @@ -63,33 +64,36 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q, const float ratio = expf(-1.0f * slope[h]); + // Compute new_kv for (int d = tid; d < qk_dim; d += blockDim.x) { - T k_val = k_shared[d]; + const T k_val = k_shared[d]; for (int e = 0; e < v_dim; ++e) { - int past_kv_idx = kv_offset + d * v_dim + e; - T v_val = v_shared[e]; - float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; - int shared_idx = d * (v_dim + 1) + e; + const int past_kv_idx = kv_offset + d * v_dim + e; + const T v_val = v_shared[e]; + const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; + const int shared_idx = d * (v_dim + 1) + e; new_kv_shared[shared_idx] = new_val; } } __syncthreads(); + // Store new_kv to global memory for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) { - int d = idx / v_dim; - int e = idx % v_dim; - int shared_idx = d * (v_dim + 1) + e; - int global_idx = kv_offset + idx; + const int d = idx / v_dim; + const int e = idx % v_dim; + const int shared_idx = d * (v_dim + 1) + e; + const int global_idx = kv_offset + idx; new_kv[global_idx] = new_kv_shared[shared_idx]; } __syncthreads(); + // Compute output for (int e = tid; e < v_dim; e += blockDim.x) { float sum = 0.0f; for (int d = 0; d < qk_dim; ++d) { - int shared_idx = d * (v_dim + 1) + e; + const int shared_idx = d * (v_dim + 1) + e; sum += q_shared[d] * new_kv_shared[shared_idx]; } output_shared[e] = static_cast<T>(sum); @@ -97,6 +101,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q, __syncthreads(); + // Store output to global memory if (tid == 0) { for (int e = 0; e < v_dim; ++e) { output[v_offset + e] = output_shared[e]; diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index 6346efbd3..473aae6f5 100644 --- 
a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -25,8 +25,9 @@ limitations under the License. #define WARP_SIZE 32 template <typename scalar_t> -__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* cumsum_buffer, size_t numel) { +__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; @@ -38,9 +39,10 @@ __global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* } template <typename scalar_t> -__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* cumsum) { +__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { __shared__ int32_t shared_counts[WARP_SIZE][8]; const int warp_id = threadIdx.x / WARP_SIZE; @@ -106,7 +108,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b const int max_blocks = 65535; const int actual_blocks = std::min(num_blocks, max_blocks); - auto sort_kernel = moe_token_sort_kernel<scalar_t>; + auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>; sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel()); diff --git a/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu index 894d1d332..5afe03801 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu 
+++ b/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu @@ -7,13 +7,11 @@ using FP8_TYPE = c10::Float8_e4m3fn; -__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) { - if (tid < 8) { - smem[tid] = fmaxf(smem[tid], smem[tid + 8]); - if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]); - if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]); - if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]); - } +__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) { + smem[tid] = fmaxf(smem[tid], smem[tid + 8]); + if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]); + if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]); + if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]); return smem[0]; } @@ -53,7 +51,7 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo // Perform reduction within each group if (local_tid < 8) { - WarpReduce(&s_absmax[local_group_id][0], local_tid); + GroupReduce(&s_absmax[local_group_id][0], local_tid); } __syncthreads();