diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 80aeab257..6534a4a60 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -1,6 +1,8 @@ import unittest import torch +import torch.nn.functional as F +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm from sglang.srt.layers.activation import SiluAndMul @@ -11,6 +13,37 @@ class TestFusedMOE(unittest.TestCase): NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] + @staticmethod + def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): + """Create a random CUDA tensor + + Args: + shape: Tensor shape + dtype: Data type + mean: Mean value + std: Standard deviation + + Returns: + torch.Tensor: Randomly initialized CUDA tensor + """ + return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std) + + def get_tolerance(self, dtype): + """Get tolerance values for different data types + + Args: + dtype: Data type + + Returns: + tuple: (relative tolerance, absolute tolerance) + """ + if dtype == torch.float32: + return 1e-3, 1e-5 + elif dtype in [torch.float16, torch.bfloat16]: + return 1e-1, 1e-2 + else: + return 1e-2, 1e-2 # Default values for other types + def torch_naive_moe(self, a, w1, w2, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -30,23 +63,25 @@ class TestFusedMOE(unittest.TestCase): ).sum(dim=1) def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): + rtol, atol = self.get_tolerance(dtype) + if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 capability = torch.cuda.get_device_capability() if not (capability[0] >= 9 or capability == (8, 9)): return - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) w1 = w1.to(torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fn) - score = torch.randn((m, e), device="cuda", dtype=dtype) + score = self.create_random_cuda_tensor((m, e), dtype) - w1_scale = torch.randn(e, dtype=torch.float32, device="cuda") - w2_scale = torch.randn(e, dtype=torch.float32, device="cuda") - a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") - a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + w1_scale = self.create_random_cuda_tensor(e, torch.float32) + w2_scale = self.create_random_cuda_tensor(e, torch.float32) + a1_scale = self.create_random_cuda_tensor(1, torch.float32) + a2_scale = self.create_random_cuda_tensor(1, torch.float32) sglang_output = fused_moe( a, @@ -76,17 +111,19 @@ class TestFusedMOE(unittest.TestCase): a2_scale=a2_scale, ) - torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0) + torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol) else: - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) + score = self.create_random_cuda_tensor((m, e), dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = self.torch_naive_moe(a, w1, w2, score, topk) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close( + triton_output, torch_output, rtol=rtol, atol=atol + ) def test_various_configurations(self): m_values = [1, 33, 64, 222, 1024 * 128] @@ -95,31 +132,45 @@ class TestFusedMOE(unittest.TestCase): dtypes = [torch.float16, torch.bfloat16] fp8_modes = [False, True] - for m in m_values: - for n in n_values: - for k in k_values: - for e in self.NUM_EXPERTS: - for topk in self.TOP_KS: - for dtype in dtypes: - for use_fp8_w8a8 in fp8_modes: - with self.subTest( - m=m, - n=n, - k=k, - e=e, - topk=topk, - dtype=dtype, - fp8=use_fp8_w8a8, - ): - self._test_case( - m, - n, - k, - e, - topk, - dtype, - use_fp8_w8a8=use_fp8_w8a8, - ) + # Calculate total number of tests + total_tests = ( + len(m_values) + * len(n_values) + * len(k_values) + * len(self.NUM_EXPERTS) + * len(self.TOP_KS) + * len(dtypes) + * len(fp8_modes) + ) + + # Create progress bar + with tqdm(total=total_tests, desc="Running MoE tests") as pbar: + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + for use_fp8_w8a8 in fp8_modes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + fp8=use_fp8_w8a8, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + use_fp8_w8a8=use_fp8_w8a8, + ) + pbar.update(1) if __name__ == "__main__":