diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index b037e7a92..5560cea67 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -76,6 +76,7 @@ suites = {
         TestFile("test_create_kvindices.py", 2),
         TestFile("test_hicache.py", 60),
         TestFile("test_hicache_mla.py", 90),
+        TestFile("test_fused_moe.py", 30),
         TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
     ],
     "per-commit-2-gpu": [
diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py
index fcff74d62..9b6af04bc 100644
--- a/test/srt/test_fused_moe.py
+++ b/test/srt/test_fused_moe.py
@@ -3,7 +3,6 @@ import unittest
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
-from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
@@ -45,7 +44,18 @@ class TestFusedMOE(CustomTestCase):
         else:
             return 1e-2, 1e-2  # Default values for other types
 
-    def torch_naive_moe(self, a, w1, w2, score, topk):
+    def torch_naive_moe(
+        self,
+        a,
+        w1,
+        w2,
+        score,
+        topk,
+        w1_scale=None,
+        w2_scale=None,
+        a1_scale=None,
+        a2_scale=None,
+    ):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
@@ -53,12 +63,30 @@ class TestFusedMOE(CustomTestCase):
         topk_weight, topk_ids = torch.topk(score, topk)
         topk_weight = topk_weight.view(-1)
         topk_ids = topk_ids.view(-1)
-        for i in range(w1.shape[0]):
+
+        if w1.dtype == torch.float8_e4m3fn:
+            w1_compute = w1.to(a.dtype)
+            w2_compute = w2.to(a.dtype)
+
+            if w1_scale is not None:
+                w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
+            if w2_scale is not None:
+                w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
+            if a1_scale is not None:
+                a = (a * a1_scale).to(a.dtype)
+            if a2_scale is not None:
+                a = (a * a2_scale).to(a.dtype)
+        else:
+            w1_compute = w1
+            w2_compute = w2
+
+        for i in range(w1_compute.shape[0]):
             mask = topk_ids == i
             if mask.sum():
-                out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
-                    i
-                ].transpose(0, 1)
+                out[mask] = SiluAndMul()(
+                    a[mask] @ w1_compute[i].transpose(0, 1)
+                ) @ w2_compute[i].transpose(0, 1)
+
         return (
             out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
         ).sum(dim=1)
@@ -98,21 +126,12 @@ class TestFusedMOE(CustomTestCase):
                 a2_scale=a2_scale,
             )
 
-            vllm_output = fused_moe_vllm(
-                a,
-                w1,
-                w2,
-                score,
-                topk,
-                renormalize=False,
-                use_fp8_w8a8=True,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                a1_scale=a1_scale,
-                a2_scale=a2_scale,
+            torch_output = self.torch_naive_moe(
+                a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
+            )
+            torch.testing.assert_close(
+                sglang_output, torch_output, rtol=rtol, atol=atol
             )
-
-            torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
         else:
             a = self.create_random_cuda_tensor((m, k), dtype)
@@ -127,8 +146,8 @@ class TestFusedMOE(CustomTestCase):
         )
 
     def test_various_configurations(self):
-        m_values = [1, 33, 64, 222, 1024 * 128]
-        n_values = [128, 1024, 2048]
+        m_values = [1, 33, 64, 222]
+        n_values = [128, 1024]
         k_values = [128, 511, 1024]
         dtypes = [torch.float16, torch.bfloat16]
         fp8_modes = [False, True]
@@ -171,6 +190,7 @@ class TestFusedMOE(CustomTestCase):
                 dtype,
                 use_fp8_w8a8=use_fp8_w8a8,
             )
+            torch.cuda.empty_cache()
             pbar.update(1)