fix moe_align_block_size (#2615)
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.0.2.post9"
|
version = "0.0.2.post10"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
|
|||||||
@@ -118,31 +118,19 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
|
|||||||
}
|
}
|
||||||
|
|
||||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
|
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
|
||||||
torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
|
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
|
||||||
torch::Tensor num_tokens_post_pad) {
|
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||||
// tensors
|
// tensors
|
||||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||||
|
|
||||||
const int32_t mem_tokens_cnts = ((num_experts + 1) * num_experts) * sizeof(int32_t);
|
|
||||||
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
|
|
||||||
|
|
||||||
// allocate global memory
|
|
||||||
int32_t* tokens_cnts;
|
|
||||||
int32_t* cumsum;
|
|
||||||
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
|
|
||||||
cudaMalloc(&cumsum, mem_cumsum);
|
|
||||||
|
|
||||||
// set dynamic shared mem
|
|
||||||
auto kernel = moe_align_block_size_kernel<scalar_t>;
|
auto kernel = moe_align_block_size_kernel<scalar_t>;
|
||||||
kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||||
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||||
num_experts, block_size, topk_ids.numel(), tokens_cnts, cumsum);
|
num_experts, block_size, topk_ids.numel(),
|
||||||
|
token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
|
||||||
cudaFree(tokens_cnts);
|
|
||||||
cudaFree(cumsum);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ def moe_align_block_size(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
experts_ids,
|
experts_ids,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
|
token_cnts_buffer,
|
||||||
|
cumsum_buffer,
|
||||||
):
|
):
|
||||||
_moe_align_block_size(
|
_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -16,4 +18,6 @@ def moe_align_block_size(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
experts_ids,
|
experts_ids,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
|
token_cnts_buffer,
|
||||||
|
cumsum_buffer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,8 +18,22 @@ def test_moe_align_block_size():
|
|||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
|
|
||||||
|
token_cnts_buffer = torch.empty(
|
||||||
|
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
cumsum_buffer = torch.empty(
|
||||||
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
moe_align_block_size(
|
moe_align_block_size(
|
||||||
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_pad,
|
||||||
|
token_cnts_buffer,
|
||||||
|
cumsum_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user