[chore] Remove unused ep_moe cuda kernels (#9956)
@@ -1,164 +0,0 @@
import itertools

import pytest
import torch
from sgl_kernel import ep_moe_post_reorder

from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel


def create_test_tensors(
    batch_size: int,
    hidden_size: int,
    topk: int,
    start_expert_id: int,
    end_expert_id: int,
    dtype: torch.dtype,
    device: torch.device,
):
    down_output = torch.randn(
        batch_size * topk, hidden_size, dtype=dtype, device=device
    )

    # Ensure src2dst has no duplicate destinations to avoid race conditions
    total_tokens = batch_size * topk
    dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32)
    src2dst = dst_indices.view(batch_size, topk)

    topk_ids = torch.randint(
        start_expert_id,
        end_expert_id + 1,
        (batch_size, topk),
        dtype=torch.int32,
        device=device,
    )

    topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)

    return down_output, src2dst, topk_ids, topk_weights


def run_cuda_kernel(
    down_output: torch.Tensor,
    output: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
):
    ep_moe_post_reorder(
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
    )
    return output


def run_triton_kernel(
    down_output: torch.Tensor,
    output: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
    hidden_size: int,
):
    # Launch one program per destination token: output has batch_size rows,
    # while down_output has batch_size * topk rows, so the grid must come
    # from output, not down_output.
    batch_size = output.size(0)
    block_size = 512

    post_reorder_triton_kernel[(batch_size,)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
        0,  # dst_start
        block_size,
    )
    return output


def assert_close(a, b):
    a32, b32 = a.float(), b.float()
    if a.dtype is torch.float16:
        torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-2)
    elif a.dtype is torch.bfloat16:
        torch.testing.assert_close(a32, b32, rtol=1e-4, atol=1e-1)
    else:
        torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize(
    "batch_size,hidden_size,topk",
    list(itertools.product([32, 64], [128, 256, 512], [2, 4, 8])),
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_ep_moe_post_reorder_vs_triton(
    batch_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
):
    device = torch.device("cuda")
    start_expert_id = 0
    end_expert_id = 15

    (
        down_output,
        src2dst,
        topk_ids,
        topk_weights,
    ) = create_test_tensors(
        batch_size,
        hidden_size,
        topk,
        start_expert_id,
        end_expert_id,
        dtype,
        device,
    )

    output_cuda = torch.empty(batch_size, hidden_size, dtype=dtype, device=device)
    output_triton = torch.empty(batch_size, hidden_size, dtype=dtype, device=device)

    cuda_output = run_cuda_kernel(
        down_output,
        output_cuda,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
    )

    triton_output = run_triton_kernel(
        down_output,
        output_triton,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
    )

    assert_close(cuda_output, triton_output)


if __name__ == "__main__":
    pytest.main([__file__])
@@ -1,181 +0,0 @@
import itertools

import pytest
import torch
from sgl_kernel import ep_moe_pre_reorder

from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel


def create_test_tensors(
    batch_size: int,
    hidden_size: int,
    topk: int,
    start_expert_id: int,
    end_expert_id: int,
    dtype: torch.dtype,
    device: torch.device,
    use_per_token_if_dynamic: bool = True,
):
    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)

    # Ensure src2dst has no duplicate destinations to avoid race conditions
    total_tokens = batch_size * topk
    dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32)
    src2dst = dst_indices.view(batch_size, topk)

    topk_ids = torch.randint(
        start_expert_id,
        end_expert_id + 1,
        (batch_size, topk),
        dtype=torch.int32,
        device=device,
    )

    if use_per_token_if_dynamic:
        a1_scales = (
            torch.rand(batch_size, dtype=torch.float32, device=device) * 0.8 + 0.2
        )
    else:
        a1_scales = (
            torch.rand(
                end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
            )
            * 0.8
            + 0.2
        )

    return input_tensor, src2dst, topk_ids, a1_scales


def run_cuda_kernel(
    input_tensor: torch.Tensor,
    gateup_input: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
    use_per_token_if_dynamic: bool,
):
    ep_moe_pre_reorder(
        input_tensor,
        gateup_input,
        src2dst,
        topk_ids,
        a1_scales,
        start_expert_id,
        end_expert_id,
        topk,
        use_per_token_if_dynamic,
    )
    return gateup_input


def run_triton_kernel(
    input_tensor: torch.Tensor,
    gateup_input: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
    hidden_size: int,
    use_per_token_if_dynamic: bool,
):
    batch_size = input_tensor.size(0)
    block_size = 512

    pre_reorder_triton_kernel[(batch_size,)](
        input_tensor,
        gateup_input,
        src2dst,
        topk_ids,
        a1_scales,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
        block_size,
        use_per_token_if_dynamic,
    )
    return gateup_input


@pytest.mark.parametrize(
    "batch_size,hidden_size,topk",
    list(itertools.product([32, 64, 128], [512, 1024, 2048], [4, 8])),
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("use_per_token_if_dynamic", [True, False])
def test_ep_moe_pre_reorder_vs_triton(
    batch_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_per_token_if_dynamic: bool,
):
    device = torch.device("cuda")
    start_expert_id = 0
    end_expert_id = 15

    (
        input_tensor,
        src2dst,
        topk_ids,
        a1_scales,
    ) = create_test_tensors(
        batch_size,
        hidden_size,
        topk,
        start_expert_id,
        end_expert_id,
        dtype,
        device,
        use_per_token_if_dynamic,
    )

    gateup_input_cuda = torch.empty(
        batch_size * topk, hidden_size, dtype=dtype, device=device
    )
    gateup_input_triton = torch.empty(
        batch_size * topk, hidden_size, dtype=dtype, device=device
    )

    cuda_output = run_cuda_kernel(
        input_tensor,
        gateup_input_cuda,
        src2dst,
        topk_ids,
        a1_scales,
        start_expert_id,
        end_expert_id,
        topk,
        use_per_token_if_dynamic,
    )

    triton_output = run_triton_kernel(
        input_tensor,
        gateup_input_triton,
        src2dst,
        topk_ids,
        a1_scales,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
        use_per_token_if_dynamic,
    )

    torch.testing.assert_close(
        cuda_output.float(),
        triton_output.float(),
        rtol=1e-5,
        atol=1e-5,
    )


if __name__ == "__main__":
    pytest.main([__file__])
@@ -1,142 +0,0 @@
import itertools

import pytest
import torch
from sgl_kernel import ep_moe_silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel


def create_test_tensors(
    total_tokens: int,
    hidden_size: int,
    start_expert_id: int,
    end_expert_id: int,
    dtype: torch.dtype,
    device: torch.device,
):
    gateup_output = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device)

    reorder_topk_ids = torch.randint(
        start_expert_id,
        end_expert_id + 1,
        (total_tokens,),
        dtype=torch.int32,
        device=device,
    )

    num_experts = end_expert_id - start_expert_id + 1
    scales = torch.rand(num_experts, dtype=torch.float32, device=device) * 0.8 + 0.5

    half_hidden = hidden_size // 2
    down_input = torch.empty(total_tokens, half_hidden, dtype=dtype, device=device)

    return gateup_output, down_input, reorder_topk_ids, scales


def run_cuda_kernel(
    gateup_output: torch.Tensor,
    down_input: torch.Tensor,
    reorder_topk_ids: torch.Tensor,
    scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
):
    ep_moe_silu_and_mul(
        gateup_output,
        down_input,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )
    return down_input


def run_triton_kernel(
    gateup_output: torch.Tensor,
    down_input: torch.Tensor,
    reorder_topk_ids: torch.Tensor,
    scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    hidden_size: int,
):
    total_tokens = gateup_output.size(0)
    block_size = 512

    silu_and_mul_triton_kernel[(total_tokens,)](
        gateup_output,
        down_input,
        hidden_size,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
        block_size,
    )
    return down_input


@pytest.mark.parametrize(
    "total_tokens,hidden_size",
    list(itertools.product([32, 256, 1024], [128, 256, 512])),
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_ep_moe_silu_and_mul_vs_triton(
    total_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
):
    device = torch.device("cuda")
    start_expert_id = 0
    end_expert_id = 15

    (
        gateup_output,
        _,
        reorder_topk_ids,
        scales,
    ) = create_test_tensors(
        total_tokens,
        hidden_size,
        start_expert_id,
        end_expert_id,
        dtype,
        device,
    )

    down_input_cuda = torch.empty(
        total_tokens, hidden_size // 2, dtype=dtype, device=device
    )
    down_input_triton = torch.empty_like(down_input_cuda)

    cuda_output = run_cuda_kernel(
        gateup_output,
        down_input_cuda,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )

    triton_output = run_triton_kernel(
        gateup_output,
        down_input_triton,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
        hidden_size,
    )

    torch.testing.assert_close(
        cuda_output,
        triton_output,
        rtol=1e-5,
        atol=1e-5,
    )


if __name__ == "__main__":
    pytest.main([__file__])