Remove the vllm dependency from the moe_align function (#4164)
Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
This commit is contained in:
@@ -47,18 +47,18 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
int32_t* __restrict__ expert_ids,
|
int32_t* __restrict__ expert_ids,
|
||||||
int32_t* __restrict__ total_tokens_post_pad,
|
int32_t* __restrict__ total_tokens_post_pad,
|
||||||
int32_t num_experts,
|
int32_t num_experts,
|
||||||
|
int32_t experts_per_warp,
|
||||||
int32_t block_size,
|
int32_t block_size,
|
||||||
size_t numel,
|
size_t numel,
|
||||||
int32_t* __restrict__ cumsum) {
|
int32_t* __restrict__ cumsum) {
|
||||||
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
extern __shared__ int32_t shared_counts[];
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
const int experts_per_warp = 8;
|
|
||||||
const int my_expert_start = warp_id * experts_per_warp;
|
const int my_expert_start = warp_id * experts_per_warp;
|
||||||
|
|
||||||
for (int i = 0; i < experts_per_warp; ++i) {
|
for (int i = 0; i < experts_per_warp; ++i) {
|
||||||
if (my_expert_start + i < num_experts) {
|
if (my_expert_start + i < num_experts) {
|
||||||
shared_counts[warp_id][i] = 0;
|
shared_counts[warp_id * experts_per_warp + i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
int expert_id = topk_ids[i];
|
int expert_id = topk_ids[i];
|
||||||
int warp_idx = expert_id / experts_per_warp;
|
int warp_idx = expert_id / experts_per_warp;
|
||||||
int expert_offset = expert_id % experts_per_warp;
|
int expert_offset = expert_id % experts_per_warp;
|
||||||
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
|
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
int expert_count = 0;
|
int expert_count = 0;
|
||||||
int warp_idx = (i - 1) / experts_per_warp;
|
int warp_idx = (i - 1) / experts_per_warp;
|
||||||
int expert_offset = (i - 1) % experts_per_warp;
|
int expert_offset = (i - 1) % experts_per_warp;
|
||||||
expert_count = shared_counts[warp_idx][expert_offset];
|
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||||
|
|
||||||
cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||||
}
|
}
|
||||||
@@ -108,16 +108,18 @@ void moe_align_block_size(
|
|||||||
torch::Tensor token_cnts_buffer,
|
torch::Tensor token_cnts_buffer,
|
||||||
torch::Tensor cumsum_buffer) {
|
torch::Tensor cumsum_buffer) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
|
TORCH_CHECK(num_experts % WARP_SIZE == 0);
|
||||||
|
int experts_per_warp = num_experts / WARP_SIZE;
|
||||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||||
align_kernel<<<1, 1024, 0, stream>>>(
|
size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
|
||||||
|
align_kernel<<<1, 1024, shared_mem_size, stream>>>(
|
||||||
topk_ids.data_ptr<scalar_t>(),
|
topk_ids.data_ptr<scalar_t>(),
|
||||||
sorted_token_ids.data_ptr<int32_t>(),
|
sorted_token_ids.data_ptr<int32_t>(),
|
||||||
experts_ids.data_ptr<int32_t>(),
|
experts_ids.data_ptr<int32_t>(),
|
||||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||||
num_experts,
|
num_experts,
|
||||||
|
experts_per_warp,
|
||||||
block_size,
|
block_size,
|
||||||
topk_ids.numel(),
|
topk_ids.numel(),
|
||||||
cumsum_buffer.data_ptr<int32_t>());
|
cumsum_buffer.data_ptr<int32_t>());
|
||||||
|
|||||||
@@ -138,18 +138,20 @@ def moe_align_block_size_triton(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"block_size,num_tokens,topk",
|
"block_size,num_tokens,topk,num_experts",
|
||||||
list(
|
list(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
[32, 64, 128, 256], # block_size
|
[32, 64, 128, 256], # block_size
|
||||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
||||||
[1, 2, 4, 8, 16, 32, 64], # topk
|
[1, 2, 4, 8, 16, 32, 64], # topk
|
||||||
|
[64, 160, 256], # num_experts
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk):
|
def test_moe_align_block_size_compare_implementations(
|
||||||
|
block_size, num_tokens, topk, num_experts
|
||||||
|
):
|
||||||
# For DeepSeek V3, we have 256 experts
|
# For DeepSeek V3, we have 256 experts
|
||||||
num_experts = 256
|
|
||||||
|
|
||||||
topk_ids = torch.stack(
|
topk_ids = torch.stack(
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user