clean moe align block kernel code and add acc test (#3332)
This commit is contained in:
@@ -310,4 +310,4 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
calculate_diff(batch_size=4, seq_len=1024)
|
calculate_diff(batch_size=4, seq_len=1024)
|
||||||
|
|
||||||
benchmark.run(print_data=True, save_path=args.save_path)
|
benchmark.run(print_data=True)
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu
|
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
@@ -22,32 +20,15 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <THC/THCAtomics.cuh>
|
#include <THC/THCAtomics.cuh>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
|
||||||
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
|
||||||
|
|
||||||
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
|
|
||||||
|
|
||||||
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
|
||||||
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
|
||||||
|
|
||||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
|
|
||||||
return row * total_col + col;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
||||||
int32_t block_size, size_t numel, int32_t* cumsum) {
|
int32_t block_size, size_t numel, int32_t* cumsum) {
|
||||||
__shared__ int32_t shared_counts[32][8];
|
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
||||||
__shared__ int32_t local_offsets[256];
|
__shared__ int32_t local_offsets[256];
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
@@ -96,6 +77,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
|
||||||
|
// If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
|
||||||
|
// kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
|
||||||
|
// illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
|
||||||
|
// results in the same issue, and a correct solution has not yet been found.
|
||||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||||
int32_t expert_id = topk_ids[i];
|
int32_t expert_id = topk_ids[i];
|
||||||
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
|
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
|
||||||
@@ -107,9 +93,11 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
|||||||
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
|
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
|
||||||
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
|
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
|
||||||
|
|
||||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
auto kernel = moe_align_block_size_kernel<scalar_t>;
|
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||||
kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||||
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||||
num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -79,3 +79,15 @@ inline int getSMVersion() {
|
|||||||
return false; \
|
return false; \
|
||||||
} \
|
} \
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
|
||||||
|
|||||||
@@ -1,40 +1,176 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
from sgl_kernel import moe_align_block_size
|
from sgl_kernel import moe_align_block_size
|
||||||
|
|
||||||
|
|
||||||
def test_moe_align_block_size():
|
def ceil_div(a, b):
|
||||||
|
return (a + b - 1) // b
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage1(
|
||||||
|
topk_ids_ptr,
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
numel: tl.constexpr,
|
||||||
|
tokens_per_thread: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
start_idx = pid * tokens_per_thread
|
||||||
|
off_c = (pid + 1) * num_experts
|
||||||
|
|
||||||
|
for i in range(tokens_per_thread):
|
||||||
|
if start_idx + i < numel:
|
||||||
|
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||||
|
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage2(
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
last_cnt = 0
|
||||||
|
for i in range(1, num_experts + 1):
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||||
|
last_cnt = last_cnt + token_cnt
|
||||||
|
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage3(
|
||||||
|
total_tokens_post_pad_ptr,
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
cumsum_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
):
|
||||||
|
last_cumsum = 0
|
||||||
|
off_cnt = num_experts * num_experts
|
||||||
|
for i in range(1, num_experts + 1):
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||||
|
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||||
|
tl.store(cumsum_ptr + i, last_cumsum)
|
||||||
|
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage4(
|
||||||
|
topk_ids_ptr,
|
||||||
|
sorted_token_ids_ptr,
|
||||||
|
expert_ids_ptr,
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
cumsum_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
numel: tl.constexpr,
|
||||||
|
tokens_per_thread: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
start_idx = tl.load(cumsum_ptr + pid)
|
||||||
|
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||||
|
|
||||||
|
for i in range(start_idx, end_idx, block_size):
|
||||||
|
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||||
|
|
||||||
|
start_idx = pid * tokens_per_thread
|
||||||
|
off_t = pid * num_experts
|
||||||
|
|
||||||
|
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
|
||||||
|
expert_id = tl.load(topk_ids_ptr + i)
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||||
|
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||||
|
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||||
|
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def moe_align_block_size_triton(
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
block_size: int,
|
||||||
|
sorted_token_ids: torch.Tensor,
|
||||||
|
expert_ids: torch.Tensor,
|
||||||
|
num_tokens_post_pad: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
numel = topk_ids.numel()
|
||||||
|
grid = (num_experts,)
|
||||||
|
tokens_cnts = torch.zeros(
|
||||||
|
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
|
||||||
|
tokens_per_thread = ceil_div(numel, num_experts)
|
||||||
|
|
||||||
|
moe_align_block_size_stage1[grid](
|
||||||
|
topk_ids,
|
||||||
|
tokens_cnts,
|
||||||
|
num_experts,
|
||||||
|
numel,
|
||||||
|
tokens_per_thread,
|
||||||
|
)
|
||||||
|
moe_align_block_size_stage2[grid](
|
||||||
|
tokens_cnts,
|
||||||
|
num_experts,
|
||||||
|
)
|
||||||
|
moe_align_block_size_stage3[(1,)](
|
||||||
|
num_tokens_post_pad,
|
||||||
|
tokens_cnts,
|
||||||
|
cumsum,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
)
|
||||||
|
moe_align_block_size_stage4[grid](
|
||||||
|
topk_ids,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
tokens_cnts,
|
||||||
|
cumsum,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
numel,
|
||||||
|
tokens_per_thread,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"block_size,num_tokens,topk",
|
||||||
|
list(
|
||||||
|
itertools.product(
|
||||||
|
[32, 64, 128, 256], # block_size
|
||||||
|
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
||||||
|
[1, 2, 4, 8, 16, 32, 64], # topk
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_moe_align_block_size_compare_implementations(block_size, num_tokens, topk):
|
||||||
# For DeepSeek V3, we have 256 experts
|
# For DeepSeek V3, we have 256 experts
|
||||||
num_experts = 256
|
num_experts = 256
|
||||||
|
|
||||||
# Test different combinations of block_size, num_tokens and topk
|
topk_ids = torch.stack(
|
||||||
for block_size in [32, 64, 128, 256]:
|
[
|
||||||
print(f"\nTesting block_size={block_size}")
|
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||||
for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
for _ in range(num_tokens)
|
||||||
for topk in [1, 2, 4, 8, 16, 32, 64]:
|
]
|
||||||
print(
|
|
||||||
f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create random topk_ids with shape [num_tokens, topk]
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
topk_ids = torch.randint(
|
|
||||||
0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (
|
sorted_ids_cuda = torch.empty(
|
||||||
block_size - 1
|
|
||||||
)
|
|
||||||
sorted_ids = torch.empty(
|
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||||
expert_ids = torch.empty(
|
expert_ids_cuda = torch.zeros(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty(
|
num_tokens_post_pad_cuda = torch.empty(
|
||||||
(1), dtype=torch.int32, device=topk_ids.device
|
(1), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
|
|
||||||
token_cnts_buffer = torch.empty(
|
token_cnts_buffer = torch.empty(
|
||||||
(num_experts + 1) * num_experts,
|
(num_experts + 1) * num_experts,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@@ -44,24 +180,45 @@ def test_moe_align_block_size():
|
|||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
||||||
|
sorted_ids_triton.fill_(topk_ids.numel())
|
||||||
|
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
|
||||||
|
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
||||||
|
|
||||||
moe_align_block_size(
|
moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
sorted_ids,
|
sorted_ids_cuda,
|
||||||
expert_ids,
|
expert_ids_cuda,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad_cuda,
|
||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
print(
|
moe_align_block_size_triton(
|
||||||
f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids_triton,
|
||||||
|
expert_ids_triton,
|
||||||
|
num_tokens_post_pad_triton,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
|
||||||
|
f"Expert IDs mismatch for block_size={block_size}, "
|
||||||
|
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||||
|
f"CUDA expert_ids: {expert_ids_cuda}\n"
|
||||||
|
f"Triton expert_ids: {expert_ids_triton}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
|
||||||
|
f"Num tokens post pad mismatch for block_size={block_size}, "
|
||||||
|
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||||
|
f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
|
||||||
|
f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
|
||||||
)
|
)
|
||||||
print(f"Error message: {str(e)}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_moe_align_block_size()
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user