Revert "[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x)" (#3982)
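This revert returns the moe_align_block_size benchmark to its pre-multi-block code paths, shown in the hunks below. As orientation for the diff, here is a rough pure-PyTorch sketch of what the alignment outputs mean; moe_align_reference is a hypothetical name, and the layout is inferred from the buffers used in this file, not taken from the CUDA/Triton kernel sources:

    import torch

    def moe_align_reference(topk_ids, num_experts, block_size):
        # Hypothetical sketch of the alignment semantics this file benchmarks,
        # inferred from sorted_ids / expert_ids / num_tokens_post_pad below.
        numel = topk_ids.numel()
        flat = topk_ids.flatten().long()
        counts = torch.bincount(flat, minlength=num_experts)
        # Pad each expert's token count up to a multiple of block_size.
        padded = (counts + block_size - 1) // block_size * block_size
        cumsum = torch.cat([counts.new_zeros(1), padded.cumsum(0)])
        num_tokens_post_pad = int(cumsum[-1])
        # Token indices grouped by expert; pad slots keep the sentinel `numel`.
        sorted_ids = torch.full(
            (num_tokens_post_pad,), numel, dtype=torch.int32, device=flat.device
        )
        for e in range(num_experts):
            idx = (flat == e).nonzero(as_tuple=True)[0].to(torch.int32)
            sorted_ids[int(cumsum[e]) : int(cumsum[e]) + idx.numel()] = idx
        # Owner expert of each block_size-sized block of sorted_ids.
        expert_ids = torch.repeat_interleave(
            torch.arange(num_experts, dtype=torch.int32, device=flat.device),
            padded // block_size,
        )
        return sorted_ids, expert_ids, num_tokens_post_pad

    # Example: 16 tokens, topk=2, 8 experts, blocks of 4.
    ids, owners, total = moe_align_reference(
        torch.randint(0, 8, (16, 2)), num_experts=8, block_size=4
    )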
@@ -99,12 +99,13 @@ def moe_align_block_size_triton(
     sorted_token_ids: torch.Tensor,
     expert_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
-    tokens_cnts: torch.Tensor,
-    cumsum: torch.Tensor,
 ) -> None:
     numel = topk_ids.numel()
     grid = (num_experts,)
-
+    tokens_cnts = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
     tokens_per_thread = ceil_div(numel, num_experts)
 
     moe_align_block_size_stage1[grid](
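The restored Triton helper partitions work one program per expert. A plain-Python sketch of that launch arithmetic follows; ceil_div mirrors the helper referenced above, and treating each program as filling its own row of the (num_experts + 1, num_experts) scratch is an inference from that shape, not a quote from the kernel:

    def ceil_div(a, b):  # same helper the kernel file uses
        return (a + b - 1) // b

    numel = 4 * 1024 * 8        # e.g. batch_size=4, seq_len=1024, topk=8
    num_experts = 256
    grid = (num_experts,)       # one Triton program instance per expert
    tokens_per_thread = ceil_div(numel, num_experts)
    print(grid, tokens_per_thread)  # (256,) 128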
@@ -138,18 +139,11 @@ def moe_align_block_size_triton(
     )
 
 
-def calculate_diff(batch_size, seq_len, num_experts):
-    num_experts = num_experts
+def calculate_diff(batch_size, seq_len):
+    num_experts = 256
     block_size = 128
     topk = 8
 
-    assert batch_size >= 1
-    assert seq_len >= 1
-    assert num_experts >= 4
-
-    if topk > num_experts:
-        topk = num_experts
-
     topk_ids = torch.stack(
         [
             torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
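For reference, the benchmark builds routing ids with randperm so a token never repeats an expert, as in the context lines above. A minimal standalone version (sizes are examples; device="cuda" matches the benchmark's assumption of a GPU):

    import torch

    num_experts, topk, num_tokens = 256, 8, 4
    topk_ids = torch.stack(
        [
            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
            for _ in range(num_tokens)
        ]
    )
    print(topk_ids.shape)  # (num_tokens, topk), distinct experts per row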
@@ -181,13 +175,6 @@ def calculate_diff(batch_size, seq_len, num_experts):
     expert_ids_triton = torch.zeros_like(expert_ids_cuda)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
 
-    token_cnts_buffer_triton = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer_triton = torch.zeros(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-
     # compare the performance of cuda and triton implementation
     moe_align_block_size(
         topk_ids,
@@ -206,27 +193,14 @@ def calculate_diff(batch_size, seq_len, num_experts):
         sorted_ids_triton,
         expert_ids_triton,
         num_tokens_post_pad_triton,
-        token_cnts_buffer_triton,
-        cumsum_buffer_triton,
     )
 
-    sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values
-    sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values
-
-    if (
-        torch.allclose(expert_ids_cuda, expert_ids_triton)
-        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton)
-        and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot)
+    if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
+        num_tokens_post_pad_cuda, num_tokens_post_pad_triton
     ):
-        print(
-            "✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format(
-                batch_size * seq_len, num_experts
-            )
-        )
+        print("✅ CUDA and Triton implementations match")
     else:
         print("❌ CUDA and Triton implementations do not match")
-        print("CUDA sorted ids:", sorted_ids_cuda_snapshot)
-        print("Triton sorted ids:", sorted_ids_triton_snapshot)
         print("CUDA expert_ids:", expert_ids_cuda)
         print("Triton expert_ids:", expert_ids_triton)
         print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
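The deleted snapshot check sorted both prefixes before comparing, presumably because ordering within one expert's segment may differ between the two implementations while both remain correct. A tiny illustration with made-up values:

    import torch

    a = torch.tensor([3, 1, 2], dtype=torch.int32)  # hypothetical CUDA segment
    b = torch.tensor([1, 2, 3], dtype=torch.int32)  # hypothetical Triton segment
    print(torch.allclose(a, b))                               # False: raw order differs
    print(torch.allclose(a.sort().values, b.sort().values))  # True: same token set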
@@ -282,7 +256,7 @@ def benchmark(batch_size, seq_len, provider):
     )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.zeros(
+    expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
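On the buffer sizing above: max_num_m_blocks divides the padded worst case by block_size. The max_num_tokens_padded formula below is an assumption (the usual worst case, defined outside this hunk), in which every expert wastes at most block_size - 1 pad slots:

    batch_size, seq_len, topk = 4, 1024, 8
    num_experts, block_size = 256, 128
    numel = batch_size * seq_len * topk
    max_num_tokens_padded = numel + num_experts * (block_size - 1)  # assumed formula
    max_num_m_blocks = max_num_tokens_padded // block_size          # as in benchmark()
    print(max_num_tokens_padded, max_num_m_blocks)                  # 65280 510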
@@ -293,37 +267,34 @@ def benchmark(batch_size, seq_len, provider):
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
 
-    # Warm up
-    api_func = (
-        moe_align_block_size if provider == "cuda" else moe_align_block_size_triton
-    )
-    for _ in range(10):
-        api_func(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer.clone(),
-            cumsum_buffer.clone(),
-        )
-    torch.cuda.synchronize()
-
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(
-        lambda: api_func(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer.clone(),
-            cumsum_buffer.clone(),
-        ),
-        quantiles=quantiles,
-    )
+    if provider == "cuda":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+                token_cnts_buffer,
+                cumsum_buffer,
+            ),
+            quantiles=quantiles,
+        )
+    else:
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+            ),
+            quantiles=quantiles,
+        )
+
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
 
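The restored timing path calls triton.testing.do_bench directly per provider. A minimal standalone sketch of that API as used here (the matmul is just a stand-in workload): with quantiles=[0.5, 0.2, 0.8] it returns (p50, p20, p80) times in milliseconds, which the return statement above scales by 1000 to report microseconds.

    import torch
    import triton

    x = torch.randn(2048, 2048, device="cuda")
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x @ x, quantiles=quantiles)
    print(f"median {1000 * ms:.1f} us (p20 {1000 * min_ms:.1f}, p80 {1000 * max_ms:.1f})")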
@@ -335,22 +306,8 @@ if __name__ == "__main__":
         default="./configs/benchmark_ops/moe_align_blocks/",
         help="Path to save moe align benchmark results",
     )
-    parser.add_argument(
-        "--verify",
-        action="store_true",
-        help="verify kernel",
-    )
-
     args = parser.parse_args()
 
-    if args.verify:
-        num_experts_range = [2**i for i in range(3, 9)]
+    calculate_diff(batch_size=4, seq_len=1024)
 
-        configs = list(
-            itertools.product(batch_size_range, seq_length_range, num_experts_range)
-        )
-
-        for bs, seq, num_experts in configs:
-            calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts)
-    else:
-        benchmark.run(print_data=True, save_path=args.save_path)
+    benchmark.run(print_data=True)