Remove the vllm dependency from the moe_align function (#4164)
Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
This commit is contained in:
@@ -47,18 +47,18 @@ __global__ void moe_align_block_size_kernel(
|
||||
int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t experts_per_warp,
|
||||
int32_t block_size,
|
||||
size_t numel,
|
||||
int32_t* __restrict__ cumsum) {
|
||||
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
||||
extern __shared__ int32_t shared_counts[];
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int experts_per_warp = 8;
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
for (int i = 0; i < experts_per_warp; ++i) {
|
||||
if (my_expert_start + i < num_experts) {
|
||||
shared_counts[warp_id][i] = 0;
|
||||
shared_counts[warp_id * experts_per_warp + i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
int expert_id = topk_ids[i];
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
|
||||
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
|
||||
int expert_count = 0;
|
||||
int warp_idx = (i - 1) / experts_per_warp;
|
||||
int expert_offset = (i - 1) % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx][expert_offset];
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
|
||||
cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
@@ -108,16 +108,18 @@ void moe_align_block_size(
|
||||
torch::Tensor token_cnts_buffer,
|
||||
torch::Tensor cumsum_buffer) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
|
||||
|
||||
TORCH_CHECK(num_experts % WARP_SIZE == 0);
|
||||
int experts_per_warp = num_experts / WARP_SIZE;
|
||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||
align_kernel<<<1, 1024, 0, stream>>>(
|
||||
size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
|
||||
align_kernel<<<1, 1024, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
experts_per_warp,
|
||||
block_size,
|
||||
topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
|
||||
@@ -138,18 +138,20 @@ def moe_align_block_size_triton(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_size,num_tokens,topk",
|
||||
"block_size,num_tokens,topk,num_experts",
|
||||
list(
|
||||
itertools.product(
|
||||
[32, 64, 128, 256], # block_size
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
||||
[1, 2, 4, 8, 16, 32, 64], # topk
|
||||
[64, 160, 256], # num_experts
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk):
|
||||
def test_moe_align_block_size_compare_implementations(
|
||||
block_size, num_tokens, topk, num_experts
|
||||
):
|
||||
# For DeepSeek V3, we have 256 experts
|
||||
num_experts = 256
|
||||
|
||||
topk_ids = torch.stack(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user