optimize test_fused_moe style (#3268)
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
|
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
@@ -11,6 +13,37 @@ class TestFusedMOE(unittest.TestCase):
|
|||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
TOP_KS = [2, 6]
|
TOP_KS = [2, 6]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
|
||||||
|
"""Create a random CUDA tensor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: Tensor shape
|
||||||
|
dtype: Data type
|
||||||
|
mean: Mean value
|
||||||
|
std: Standard deviation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Randomly initialized CUDA tensor
|
||||||
|
"""
|
||||||
|
return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)
|
||||||
|
|
||||||
|
def get_tolerance(self, dtype):
    """Return the comparison tolerances to use for a given data type.

    Args:
        dtype: Data type of the tensors being compared.

    Returns:
        tuple: (relative tolerance, absolute tolerance)
    """
    half_precision_dtypes = (torch.float16, torch.bfloat16)

    # The two checks are mutually exclusive, so their order does not matter.
    if dtype in half_precision_dtypes:
        return 1e-1, 1e-2
    if dtype == torch.float32:
        return 1e-3, 1e-5

    # Default values for other types
    return 1e-2, 1e-2
|
||||||
|
|
||||||
def torch_naive_moe(self, a, w1, w2, score, topk):
|
def torch_naive_moe(self, a, w1, w2, score, topk):
|
||||||
B, D = a.shape
|
B, D = a.shape
|
||||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
@@ -30,23 +63,25 @@ class TestFusedMOE(unittest.TestCase):
|
|||||||
).sum(dim=1)
|
).sum(dim=1)
|
||||||
|
|
||||||
def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
|
def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
|
||||||
|
rtol, atol = self.get_tolerance(dtype)
|
||||||
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||||
capability = torch.cuda.get_device_capability()
|
capability = torch.cuda.get_device_capability()
|
||||||
if not (capability[0] >= 9 or capability == (8, 9)):
|
if not (capability[0] >= 9 or capability == (8, 9)):
|
||||||
return
|
return
|
||||||
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
w2 = self.create_random_cuda_tensor((e, k, n), dtype)
|
||||||
w1 = w1.to(torch.float8_e4m3fn)
|
w1 = w1.to(torch.float8_e4m3fn)
|
||||||
w2 = w2.to(torch.float8_e4m3fn)
|
w2 = w2.to(torch.float8_e4m3fn)
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||||
|
|
||||||
w1_scale = torch.randn(e, dtype=torch.float32, device="cuda")
|
w1_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||||
w2_scale = torch.randn(e, dtype=torch.float32, device="cuda")
|
w2_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||||
a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
|
a1_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||||
a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
|
a2_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||||
|
|
||||||
sglang_output = fused_moe(
|
sglang_output = fused_moe(
|
||||||
a,
|
a,
|
||||||
@@ -76,17 +111,19 @@ class TestFusedMOE(unittest.TestCase):
|
|||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
|
torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
w2 = self.create_random_cuda_tensor((e, k, n), dtype)
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||||
|
|
||||||
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
|
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
|
||||||
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
|
||||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
torch.testing.assert_close(
|
||||||
|
triton_output, torch_output, rtol=rtol, atol=atol
|
||||||
|
)
|
||||||
|
|
||||||
def test_various_configurations(self):
|
def test_various_configurations(self):
|
||||||
m_values = [1, 33, 64, 222, 1024 * 128]
|
m_values = [1, 33, 64, 222, 1024 * 128]
|
||||||
@@ -95,31 +132,45 @@ class TestFusedMOE(unittest.TestCase):
|
|||||||
dtypes = [torch.float16, torch.bfloat16]
|
dtypes = [torch.float16, torch.bfloat16]
|
||||||
fp8_modes = [False, True]
|
fp8_modes = [False, True]
|
||||||
|
|
||||||
for m in m_values:
|
# Calculate total number of tests
|
||||||
for n in n_values:
|
total_tests = (
|
||||||
for k in k_values:
|
len(m_values)
|
||||||
for e in self.NUM_EXPERTS:
|
* len(n_values)
|
||||||
for topk in self.TOP_KS:
|
* len(k_values)
|
||||||
for dtype in dtypes:
|
* len(self.NUM_EXPERTS)
|
||||||
for use_fp8_w8a8 in fp8_modes:
|
* len(self.TOP_KS)
|
||||||
with self.subTest(
|
* len(dtypes)
|
||||||
m=m,
|
* len(fp8_modes)
|
||||||
n=n,
|
)
|
||||||
k=k,
|
|
||||||
e=e,
|
# Create progress bar
|
||||||
topk=topk,
|
with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
|
||||||
dtype=dtype,
|
for m in m_values:
|
||||||
fp8=use_fp8_w8a8,
|
for n in n_values:
|
||||||
):
|
for k in k_values:
|
||||||
self._test_case(
|
for e in self.NUM_EXPERTS:
|
||||||
m,
|
for topk in self.TOP_KS:
|
||||||
n,
|
for dtype in dtypes:
|
||||||
k,
|
for use_fp8_w8a8 in fp8_modes:
|
||||||
e,
|
with self.subTest(
|
||||||
topk,
|
m=m,
|
||||||
dtype,
|
n=n,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
k=k,
|
||||||
)
|
e=e,
|
||||||
|
topk=topk,
|
||||||
|
dtype=dtype,
|
||||||
|
fp8=use_fp8_w8a8,
|
||||||
|
):
|
||||||
|
self._test_case(
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
e,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user