fix sgl-kernel codestyle (#3563)
This commit is contained in:
@@ -33,11 +33,11 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
|||||||
const int batch_size, const int num_heads, const int qk_dim,
|
const int batch_size, const int num_heads, const int qk_dim,
|
||||||
const int v_dim) {
|
const int v_dim) {
|
||||||
extern __shared__ char smem[];
|
extern __shared__ char smem[];
|
||||||
T* q_shared = reinterpret_cast<T*>(smem);
|
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
|
||||||
T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
||||||
T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
||||||
float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
||||||
T* output_shared =
|
T* __restrict__ output_shared =
|
||||||
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
|
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
|
||||||
|
|
||||||
const int32_t tid = threadIdx.x;
|
const int32_t tid = threadIdx.x;
|
||||||
@@ -51,6 +51,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
|||||||
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
|
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
|
||||||
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
|
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
|
||||||
|
|
||||||
|
// Load q, k, v into shared memory
|
||||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||||
q_shared[d] = q[qk_offset + d];
|
q_shared[d] = q[qk_offset + d];
|
||||||
k_shared[d] = k[qk_offset + d];
|
k_shared[d] = k[qk_offset + d];
|
||||||
@@ -63,33 +64,36 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
|||||||
|
|
||||||
const float ratio = expf(-1.0f * slope[h]);
|
const float ratio = expf(-1.0f * slope[h]);
|
||||||
|
|
||||||
|
// Compute new_kv
|
||||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||||
T k_val = k_shared[d];
|
const T k_val = k_shared[d];
|
||||||
for (int e = 0; e < v_dim; ++e) {
|
for (int e = 0; e < v_dim; ++e) {
|
||||||
int past_kv_idx = kv_offset + d * v_dim + e;
|
const int past_kv_idx = kv_offset + d * v_dim + e;
|
||||||
T v_val = v_shared[e];
|
const T v_val = v_shared[e];
|
||||||
float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
||||||
int shared_idx = d * (v_dim + 1) + e;
|
const int shared_idx = d * (v_dim + 1) + e;
|
||||||
new_kv_shared[shared_idx] = new_val;
|
new_kv_shared[shared_idx] = new_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// Store new_kv to global memory
|
||||||
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
|
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
|
||||||
int d = idx / v_dim;
|
const int d = idx / v_dim;
|
||||||
int e = idx % v_dim;
|
const int e = idx % v_dim;
|
||||||
int shared_idx = d * (v_dim + 1) + e;
|
const int shared_idx = d * (v_dim + 1) + e;
|
||||||
int global_idx = kv_offset + idx;
|
const int global_idx = kv_offset + idx;
|
||||||
new_kv[global_idx] = new_kv_shared[shared_idx];
|
new_kv[global_idx] = new_kv_shared[shared_idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// Compute output
|
||||||
for (int e = tid; e < v_dim; e += blockDim.x) {
|
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
for (int d = 0; d < qk_dim; ++d) {
|
for (int d = 0; d < qk_dim; ++d) {
|
||||||
int shared_idx = d * (v_dim + 1) + e;
|
const int shared_idx = d * (v_dim + 1) + e;
|
||||||
sum += q_shared[d] * new_kv_shared[shared_idx];
|
sum += q_shared[d] * new_kv_shared[shared_idx];
|
||||||
}
|
}
|
||||||
output_shared[e] = static_cast<T>(sum);
|
output_shared[e] = static_cast<T>(sum);
|
||||||
@@ -97,6 +101,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// Store output to global memory
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
for (int e = 0; e < v_dim; ++e) {
|
for (int e = 0; e < v_dim; ++e) {
|
||||||
output[v_offset + e] = output_shared[e];
|
output[v_offset + e] = output_shared[e];
|
||||||
|
|||||||
@@ -25,8 +25,9 @@ limitations under the License.
|
|||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
|
||||||
int32_t* cumsum_buffer, size_t numel) {
|
int32_t* __restrict__ sorted_token_ids,
|
||||||
|
int32_t* __restrict__ cumsum_buffer, size_t numel) {
|
||||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const size_t stride = blockDim.x * gridDim.x;
|
const size_t stride = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
@@ -38,9 +39,10 @@ __global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t*
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
|
||||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||||
int32_t block_size, size_t numel, int32_t* cumsum) {
|
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||||
|
int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
|
||||||
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
@@ -106,7 +108,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
|||||||
const int max_blocks = 65535;
|
const int max_blocks = 65535;
|
||||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||||
|
|
||||||
auto sort_kernel = moe_token_sort_kernel<scalar_t>;
|
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
|
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
|
||||||
sorted_token_ids.data_ptr<int32_t>(),
|
sorted_token_ids.data_ptr<int32_t>(),
|
||||||
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
|
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
|
||||||
|
|||||||
@@ -7,13 +7,11 @@
|
|||||||
|
|
||||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||||
|
|
||||||
__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) {
|
__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) {
|
||||||
if (tid < 8) {
|
smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
|
||||||
smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
|
if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
|
||||||
if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
|
if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
|
||||||
if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
|
if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
|
||||||
if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
|
|
||||||
}
|
|
||||||
return smem[0];
|
return smem[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +51,7 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
|
|||||||
|
|
||||||
// Perform reduction within each group
|
// Perform reduction within each group
|
||||||
if (local_tid < 8) {
|
if (local_tid < 8) {
|
||||||
WarpReduce(&s_absmax[local_group_id][0], local_tid);
|
GroupReduce(&s_absmax[local_group_id][0], local_tid);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user