update sgl-kernel for EP: kernel part (#8514)
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
This commit is contained in:
@@ -164,9 +164,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
|||||||
num_tokens_post_pad_cuda = torch.empty(
|
num_tokens_post_pad_cuda = torch.empty(
|
||||||
(1), dtype=torch.int32, device=topk_ids.device
|
(1), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
token_cnts_buffer = torch.zeros(
|
|
||||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
cumsum_buffer = torch.zeros(
|
cumsum_buffer = torch.zeros(
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
@@ -189,7 +186,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
|||||||
sorted_ids_cuda,
|
sorted_ids_cuda,
|
||||||
expert_ids_cuda,
|
expert_ids_cuda,
|
||||||
num_tokens_post_pad_cuda,
|
num_tokens_post_pad_cuda,
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
)
|
)
|
||||||
moe_align_block_size_triton(
|
moe_align_block_size_triton(
|
||||||
@@ -273,11 +269,6 @@ def sgl_moe_align_block_size_with_empty(
|
|||||||
if not pad_sorted_token_ids:
|
if not pad_sorted_token_ids:
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
|
|
||||||
token_cnts_buffer = torch.empty(
|
|
||||||
(num_experts + 1) * num_experts,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=topk_ids.device,
|
|
||||||
)
|
|
||||||
cumsum_buffer = torch.empty(
|
cumsum_buffer = torch.empty(
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
@@ -289,7 +280,6 @@ def sgl_moe_align_block_size_with_empty(
|
|||||||
sorted_ids.clone(),
|
sorted_ids.clone(),
|
||||||
expert_ids.clone(),
|
expert_ids.clone(),
|
||||||
num_tokens_post_pad.clone(),
|
num_tokens_post_pad.clone(),
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
pad_sorted_token_ids,
|
pad_sorted_token_ids,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
|
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
|
||||||
"pad_sorted_token_ids) -> ()");
|
"pad_sorted_token_ids) -> ()");
|
||||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ __global__ void count_and_sort_expert_tokens_kernel(
|
|||||||
const size_t stride = blockDim.x * gridDim.x;
|
const size_t stride = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
for (size_t i = tid; i < numel; i += stride) {
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
int32_t expert_id = topk_ids[i];
|
int32_t expert_id = topk_ids[i] + 1;
|
||||||
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
||||||
sorted_token_ids[rank_post_pad] = i;
|
sorted_token_ids[rank_post_pad] = i;
|
||||||
}
|
}
|
||||||
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (size_t i = tid; i < numel; i += stride) {
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
int expert_id = topk_ids[i];
|
int expert_id = topk_ids[i] + 1;
|
||||||
atomicAdd(&shared_counts[expert_id], 1);
|
atomicAdd(&shared_counts[expert_id], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +215,7 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
right = mid;
|
right = mid;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
expert_ids[i] = left - 1;
|
expert_ids[i] = left - 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pad_sorted_token_ids) {
|
if (pad_sorted_token_ids) {
|
||||||
@@ -251,7 +251,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = tid; i < numel; i += stride) {
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
|
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -277,7 +277,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
|||||||
|
|
||||||
if (threadIdx.x < num_experts) {
|
if (threadIdx.x < num_experts) {
|
||||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
|
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
|
||||||
expert_ids[i / block_size] = threadIdx.x;
|
expert_ids[i / block_size] = threadIdx.x - 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,7 +294,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (size_t i = tid; i < numel; i += stride) {
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
int32_t expert_id = topk_ids[i];
|
int32_t expert_id = topk_ids[i] + 1;
|
||||||
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
|
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
|
||||||
sorted_token_ids[rank_post_pad] = i;
|
sorted_token_ids[rank_post_pad] = i;
|
||||||
++tokens_cnts[threadIdx.x * num_experts + expert_id];
|
++tokens_cnts[threadIdx.x * num_experts + expert_id];
|
||||||
@@ -308,7 +308,6 @@ void moe_align_block_size(
|
|||||||
torch::Tensor sorted_token_ids,
|
torch::Tensor sorted_token_ids,
|
||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad,
|
torch::Tensor num_tokens_post_pad,
|
||||||
torch::Tensor token_cnts_buffer,
|
|
||||||
torch::Tensor cumsum_buffer,
|
torch::Tensor cumsum_buffer,
|
||||||
bool pad_sorted_token_ids) {
|
bool pad_sorted_token_ids) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
|
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
|
||||||
"pad_sorted_token_ids) -> ()");
|
"pad_sorted_token_ids) -> ()");
|
||||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||||
|
|
||||||
|
|||||||
@@ -230,7 +230,6 @@ void moe_align_block_size(
|
|||||||
torch::Tensor sorted_token_ids,
|
torch::Tensor sorted_token_ids,
|
||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad,
|
torch::Tensor num_tokens_post_pad,
|
||||||
torch::Tensor token_cnts_buffer,
|
|
||||||
torch::Tensor cumsum_buffer,
|
torch::Tensor cumsum_buffer,
|
||||||
bool pad_sorted_token_ids);
|
bool pad_sorted_token_ids);
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ def moe_align_block_size(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
experts_ids,
|
experts_ids,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
pad_sorted_token_ids=False,
|
pad_sorted_token_ids=False,
|
||||||
):
|
):
|
||||||
@@ -21,7 +20,6 @@ def moe_align_block_size(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
experts_ids,
|
experts_ids,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
pad_sorted_token_ids,
|
pad_sorted_token_ids,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ def test_moe_align_block_size_compare_implementations(
|
|||||||
:, :topk
|
:, :topk
|
||||||
]
|
]
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
|
||||||
|
|
||||||
sorted_ids_cuda = torch.empty(
|
sorted_ids_cuda = torch.empty(
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
@@ -171,13 +171,8 @@ def test_moe_align_block_size_compare_implementations(
|
|||||||
num_tokens_post_pad_cuda = torch.empty(
|
num_tokens_post_pad_cuda = torch.empty(
|
||||||
(1), dtype=torch.int32, device=topk_ids.device
|
(1), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
token_cnts_buffer = torch.empty(
|
|
||||||
(num_experts + 1) * num_experts,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=topk_ids.device,
|
|
||||||
)
|
|
||||||
cumsum_buffer = torch.empty(
|
cumsum_buffer = torch.empty(
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
num_experts + 2, dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
||||||
@@ -187,19 +182,18 @@ def test_moe_align_block_size_compare_implementations(
|
|||||||
|
|
||||||
moe_align_block_size(
|
moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts + 1,
|
||||||
block_size,
|
block_size,
|
||||||
sorted_ids_cuda,
|
sorted_ids_cuda,
|
||||||
expert_ids_cuda,
|
expert_ids_cuda,
|
||||||
num_tokens_post_pad_cuda,
|
num_tokens_post_pad_cuda,
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
pad_sorted_token_ids,
|
pad_sorted_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
moe_align_block_size_triton(
|
moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts + 1,
|
||||||
block_size,
|
block_size,
|
||||||
sorted_ids_triton,
|
sorted_ids_triton,
|
||||||
expert_ids_triton,
|
expert_ids_triton,
|
||||||
|
|||||||
Reference in New Issue
Block a user