feat: support ep size < 32 for sgl kernel (#4348)
This commit is contained in:
@@ -196,14 +196,21 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
|||||||
expert_ids_triton,
|
expert_ids_triton,
|
||||||
num_tokens_post_pad_triton,
|
num_tokens_post_pad_triton,
|
||||||
)
|
)
|
||||||
ops.moe_align_block_size(
|
|
||||||
topk_ids,
|
try:
|
||||||
num_experts,
|
ops.moe_align_block_size(
|
||||||
block_size,
|
topk_ids,
|
||||||
sorted_ids_vllm,
|
num_experts,
|
||||||
expert_ids_vllm,
|
block_size,
|
||||||
num_tokens_post_pad_vllm,
|
sorted_ids_vllm,
|
||||||
)
|
expert_ids_vllm,
|
||||||
|
num_tokens_post_pad_vllm,
|
||||||
|
)
|
||||||
|
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||||
|
vllm_works = True
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||||
|
vllm_works = False
|
||||||
|
|
||||||
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
||||||
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
|
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
|
||||||
@@ -216,20 +223,26 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
|||||||
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
||||||
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
||||||
|
|
||||||
if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
|
if (
|
||||||
num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
|
vllm_works
|
||||||
|
and torch.allclose(expert_ids_cuda, expert_ids_vllm)
|
||||||
|
and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
|
||||||
):
|
):
|
||||||
print("✅ SGL and VLLM implementations match")
|
print("✅ SGL and VLLM implementations match")
|
||||||
else:
|
else:
|
||||||
print("❌ SGL and VLLM implementations do not match")
|
if not vllm_works:
|
||||||
print("SGL expert_ids:", expert_ids_cuda)
|
print("⚠️ VLLM comparison skipped due to failure")
|
||||||
print("VLLM expert_ids:", expert_ids_vllm)
|
else:
|
||||||
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
print("❌ SGL and VLLM implementations do not match")
|
||||||
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
print("SGL expert_ids:", expert_ids_cuda)
|
||||||
|
print("VLLM expert_ids:", expert_ids_vllm)
|
||||||
|
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
||||||
|
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||||
|
|
||||||
|
|
||||||
|
# Test range
|
||||||
num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||||
num_experts_range = [32, 64, 128, 256]
|
num_experts_range = [8, 32, 64, 128, 256]
|
||||||
topk_range = [2, 4, 8]
|
topk_range = [2, 4, 8]
|
||||||
|
|
||||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||||
@@ -316,17 +329,22 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
else: # vllm
|
else: # vllm
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
try:
|
||||||
lambda: ops.moe_align_block_size(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
topk_ids,
|
lambda: ops.moe_align_block_size(
|
||||||
num_experts,
|
topk_ids,
|
||||||
block_size,
|
num_experts,
|
||||||
sorted_ids.clone(),
|
block_size,
|
||||||
expert_ids.clone(),
|
sorted_ids.clone(),
|
||||||
num_tokens_post_pad.clone(),
|
expert_ids.clone(),
|
||||||
),
|
num_tokens_post_pad.clone(),
|
||||||
quantiles=quantiles,
|
),
|
||||||
)
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
|
||||||
|
# Return extreme values to indicate failure in the chart
|
||||||
|
return float("inf"), float("inf"), float("inf")
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
@@ -343,7 +361,7 @@ if __name__ == "__main__":
|
|||||||
"--num_experts",
|
"--num_experts",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
choices=[8, 64, 128, 256],
|
choices=[8, 16, 32, 64, 128, 256],
|
||||||
help="Number of experts for benchmark",
|
help="Number of experts for benchmark",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -353,8 +371,15 @@ if __name__ == "__main__":
|
|||||||
choices=[2, 4, 8],
|
choices=[2, 4, 8],
|
||||||
help="Top-k value for benchmark",
|
help="Top-k value for benchmark",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_full_benchmark",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run the calculate_diff function, skip full benchmarking",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||||
|
|
||||||
benchmark.run(print_data=True)
|
if not args.skip_full_benchmark:
|
||||||
|
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
|
||||||
|
benchmark.run(print_data=True)
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
int32_t* __restrict__ expert_ids,
|
int32_t* __restrict__ expert_ids,
|
||||||
int32_t* __restrict__ total_tokens_post_pad,
|
int32_t* __restrict__ total_tokens_post_pad,
|
||||||
int32_t num_experts,
|
int32_t num_experts,
|
||||||
|
int32_t padded_num_experts,
|
||||||
int32_t experts_per_warp,
|
int32_t experts_per_warp,
|
||||||
int32_t block_size,
|
int32_t block_size,
|
||||||
size_t numel,
|
size_t numel,
|
||||||
@@ -57,7 +58,7 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
const int my_expert_start = warp_id * experts_per_warp;
|
const int my_expert_start = warp_id * experts_per_warp;
|
||||||
|
|
||||||
for (int i = 0; i < experts_per_warp; ++i) {
|
for (int i = 0; i < experts_per_warp; ++i) {
|
||||||
if (my_expert_start + i < num_experts) {
|
if (my_expert_start + i < padded_num_experts) {
|
||||||
shared_counts[warp_id * experts_per_warp + i] = 0;
|
shared_counts[warp_id * experts_per_warp + i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -108,23 +109,44 @@ void moe_align_block_size(
|
|||||||
torch::Tensor token_cnts_buffer,
|
torch::Tensor token_cnts_buffer,
|
||||||
torch::Tensor cumsum_buffer) {
|
torch::Tensor cumsum_buffer) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
TORCH_CHECK(num_experts % WARP_SIZE == 0);
|
|
||||||
int experts_per_warp = num_experts / WARP_SIZE;
|
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||||
|
|
||||||
|
int experts_per_warp;
|
||||||
|
int threads;
|
||||||
|
|
||||||
|
if (num_experts <= 8) {
|
||||||
|
experts_per_warp = 8;
|
||||||
|
threads = 256;
|
||||||
|
} else if (num_experts <= 16) {
|
||||||
|
experts_per_warp = 16;
|
||||||
|
threads = 512;
|
||||||
|
} else {
|
||||||
|
experts_per_warp = WARP_SIZE;
|
||||||
|
threads = 1024;
|
||||||
|
}
|
||||||
|
|
||||||
|
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||||
|
|
||||||
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||||
size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
|
|
||||||
align_kernel<<<1, 1024, shared_mem_size, stream>>>(
|
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||||
|
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
||||||
|
|
||||||
|
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||||
topk_ids.data_ptr<scalar_t>(),
|
topk_ids.data_ptr<scalar_t>(),
|
||||||
sorted_token_ids.data_ptr<int32_t>(),
|
sorted_token_ids.data_ptr<int32_t>(),
|
||||||
experts_ids.data_ptr<int32_t>(),
|
experts_ids.data_ptr<int32_t>(),
|
||||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||||
num_experts,
|
num_experts,
|
||||||
|
padded_num_experts,
|
||||||
experts_per_warp,
|
experts_per_warp,
|
||||||
block_size,
|
block_size,
|
||||||
topk_ids.numel(),
|
topk_ids.numel(),
|
||||||
cumsum_buffer.data_ptr<int32_t>());
|
cumsum_buffer.data_ptr<int32_t>());
|
||||||
|
|
||||||
const int block_threads = 256;
|
const int block_threads = std::min(256, (int)threads);
|
||||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||||
const int max_blocks = 65535;
|
const int max_blocks = 65535;
|
||||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||||
|
|||||||
Reference in New Issue
Block a user