[kernel] opt moe align block kernel by block/warp scan algorithm (#7884)
This commit is contained in:
@@ -26,6 +26,12 @@ limitations under the License.
|
||||
#define VEC_SIZE 4
|
||||
using Vec = int4;
|
||||
|
||||
#ifndef __CUDA_ARCH__ // HIP
|
||||
#define SHFL_UP(mask, val, delta) __shfl_up((val), (delta))
|
||||
#else // CUDA
|
||||
#define SHFL_UP(mask, val, delta) __shfl_up_sync((mask), (val), (delta))
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void count_and_sort_expert_tokens_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
@@ -42,6 +48,16 @@ __global__ void count_and_sort_expert_tokens_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
|
||||
int original = v;
|
||||
#pragma unroll
|
||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||
int n = SHFL_UP(mask, v, offset);
|
||||
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
|
||||
}
|
||||
return v - original;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
@@ -58,6 +74,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
int32_t* shared_counts = smem; // [num_experts]
|
||||
int32_t* prefix = shared_counts + num_experts; // [num_experts + 1]
|
||||
int32_t* scan_buf = prefix + num_experts + 1; // [scan_size]
|
||||
int32_t* warp_sums = scan_buf + scan_size; // [<= 32]
|
||||
__shared__ int32_t s_total_tokens_post_pad;
|
||||
|
||||
const size_t tid = threadIdx.x;
|
||||
@@ -76,6 +93,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate padded_cnt, write scan_buf, directly prefix sum
|
||||
int32_t padded_count = 0;
|
||||
if (tid < num_experts) {
|
||||
int32_t count = shared_counts[tid];
|
||||
@@ -83,58 +101,52 @@ __global__ void moe_align_block_size_kernel(
|
||||
scan_buf[tid] = padded_count;
|
||||
}
|
||||
|
||||
if (tid >= num_experts && tid < scan_size) {
|
||||
scan_buf[tid] = 0;
|
||||
}
|
||||
|
||||
// Intra warp prefix sum
|
||||
const int warp_id = tid / WARP_SIZE;
|
||||
const int lane_id = tid & (WARP_SIZE - 1);
|
||||
const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
|
||||
if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
|
||||
__syncthreads();
|
||||
|
||||
// Blelloch scan
|
||||
int offset = 1;
|
||||
#pragma unroll
|
||||
for (int d = scan_size >> 1; d > 0; d >>= 1) {
|
||||
if (tid < d) {
|
||||
int ai = offset * (2 * tid + 1) - 1;
|
||||
int bi = offset * (2 * tid + 2) - 1;
|
||||
scan_buf[bi] += scan_buf[ai];
|
||||
}
|
||||
offset <<= 1;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// down-sweep
|
||||
if (tid == 0) {
|
||||
prefix[num_experts] = scan_buf[scan_size - 1];
|
||||
scan_buf[scan_size - 1] = 0;
|
||||
// warp0 accumulate all the block's prefix sum
|
||||
if (tid < WARP_SIZE) {
|
||||
int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
|
||||
int incl = warp_exclusive_scan(val) + val;
|
||||
warp_sums[tid] = incl;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int d = 1; d < scan_size; d <<= 1) {
|
||||
offset >>= 1;
|
||||
if (tid < d) {
|
||||
int ai = offset * (2 * tid + 1) - 1;
|
||||
int bi = offset * (2 * tid + 2) - 1;
|
||||
if (bi < scan_size) {
|
||||
int temp = scan_buf[ai];
|
||||
scan_buf[ai] = scan_buf[bi];
|
||||
scan_buf[bi] += temp;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid < num_experts) {
|
||||
prefix[tid] = scan_buf[tid];
|
||||
}
|
||||
|
||||
// Every thread obtains the whole block's sum
|
||||
if (tid == 0) {
|
||||
prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
|
||||
s_total_tokens_post_pad = prefix[num_experts];
|
||||
*total_tokens_post_pad = s_total_tokens_post_pad;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Fill 0 to scan_buf extended area (tid >= num_expert)
|
||||
if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// Perform 2 level exclusive-prefix-sum to scan_buf
|
||||
int v = (tid < scan_size) ? scan_buf[tid] : 0;
|
||||
int pre = warp_exclusive_scan(v);
|
||||
if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {
|
||||
int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
|
||||
warp_sums[lane_id] = warp_exclusive_scan(val);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int offset = warp_sums[warp_id];
|
||||
if (tid < scan_size) scan_buf[tid] = pre + offset;
|
||||
__syncthreads();
|
||||
|
||||
// Write prefix[0..num_experts - 1] and cumsum
|
||||
if (tid < num_experts) prefix[tid] = scan_buf[tid];
|
||||
if (tid <= num_experts) {
|
||||
cumsum[tid] = prefix[tid];
|
||||
}
|
||||
@@ -250,9 +262,6 @@ void moe_align_block_size(
|
||||
bool pad_sorted_token_ids) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
int experts_per_warp = WARP_SIZE;
|
||||
int threads = 1024;
|
||||
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
@@ -278,7 +287,7 @@ void moe_align_block_size(
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
|
||||
const size_t scan_size = next_pow2(num_experts);
|
||||
const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t);
|
||||
const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
|
||||
|
||||
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
|
||||
Reference in New Issue
Block a user