update sgl-kernel for EP: kernel part (#8514)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
This commit is contained in:
Cheng Wan
2025-07-30 22:19:55 -07:00
committed by GitHub
parent 59aab76f0a
commit a5f5ab4030
7 changed files with 12 additions and 32 deletions

View File

@@ -36,7 +36,7 @@ __global__ void count_and_sort_expert_tokens_kernel(
const size_t stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
int32_t expert_id = topk_ids[i] + 1;
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
__syncthreads();
for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i];
int expert_id = topk_ids[i] + 1;
atomicAdd(&shared_counts[expert_id], 1);
}
@@ -215,7 +215,7 @@ __global__ void moe_align_block_size_kernel(
right = mid;
}
}
expert_ids[i] = left - 1;
expert_ids[i] = left - 2;
}
if (pad_sorted_token_ids) {
@@ -251,7 +251,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
for (size_t i = tid; i < numel; i += stride) {
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
}
__syncthreads();
@@ -277,7 +277,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
expert_ids[i / block_size] = threadIdx.x - 1;
}
}
@@ -294,7 +294,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
__syncthreads();
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
int32_t expert_id = topk_ids[i] + 1;
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[threadIdx.x * num_experts + expert_id];
@@ -308,7 +308,6 @@ void moe_align_block_size(
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer,
bool pad_sorted_token_ids) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();