optimize moe_align_kernel cuda (#3347)
This commit is contained in:
@@ -163,10 +163,10 @@ def calculate_diff(batch_size, seq_len):
|
||||
num_tokens_post_pad_cuda = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
token_cnts_buffer = torch.empty(
|
||||
token_cnts_buffer = torch.zeros(
|
||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
cumsum_buffer = torch.zeros(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
@@ -260,10 +260,10 @@ def benchmark(batch_size, seq_len, provider):
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
token_cnts_buffer = torch.empty(
|
||||
token_cnts_buffer = torch.zeros(
|
||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
cumsum_buffer = torch.zeros(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
|
||||
@@ -417,12 +417,12 @@ def moe_align_block_size(
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
else:
|
||||
token_cnts_buffer = torch.empty(
|
||||
token_cnts_buffer = torch.zeros(
|
||||
(num_experts + 1) * num_experts,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
cumsum_buffer = torch.empty(
|
||||
cumsum_buffer = torch.zeros(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
|
||||
@@ -24,12 +24,24 @@ limitations under the License.
|
||||
|
||||
#define WARP_SIZE 32
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||
int32_t* cumsum_buffer, size_t numel) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = blockDim.x * gridDim.x;
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel, int32_t* cumsum) {
|
||||
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
||||
__shared__ int32_t local_offsets[256];
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int experts_per_warp = 8;
|
||||
@@ -72,20 +84,6 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
|
||||
// If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
|
||||
// kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
|
||||
// illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
|
||||
// results in the same issue, and a correct solution has not yet been found.
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,5 +98,15 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
||||
align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
||||
|
||||
const int block_threads = 256;
|
||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
auto sort_kernel = moe_token_sort_kernel<scalar_t>;
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user